1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <cstdint>
17 #include <deque>
18 #include <optional>
19
20 #include "absl/algorithm/container.h"
21 #include "absl/cleanup/cleanup.h"
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "tensorflow/compiler/xla/literal_util.h"
25 #include "tensorflow/compiler/xla/service/hlo_computation.h"
26 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
29 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
30 #include "tensorflow/compiler/xla/service/hlo_reachability.h"
31 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
32 #include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
33 #include "tensorflow/compiler/xla/service/shape_inference.h"
34 #include "tensorflow/compiler/xla/service/sharding_propagation.h"
35 #include "tensorflow/compiler/xla/service/spmd/convolution_handler.h"
36 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
37 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
38 #include "tensorflow/compiler/xla/shape_util.h"
39 #include "tensorflow/compiler/xla/status.h"
40 #include "tensorflow/compiler/xla/util.h"
41 #include "tensorflow/compiler/xla/window_util.h"
42 #include "tensorflow/compiler/xla/xla_data.pb.h"
43
44 namespace xla {
45 namespace spmd {
46
47 namespace {
48 using hlo_sharding_util::GroupedSharding;
49 } // namespace
50
Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) {
52 DotConvDimsMapping mapping;
53 const auto& dnums = hlo->dot_dimension_numbers();
54 int64_t next_output_dim = 0;
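// Build the dims mapping. The output dimension order produced here is batch
// dims first, then lhs non-contracting dims, then rhs non-contracting dims;
// contracting dims get output index -1 since they do not appear in the output.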
55 for (int64_t i = 0; i < dnums.lhs_batch_dimensions_size(); ++i) {
56 mapping.batch_dims.emplace_back();
57 mapping.batch_dims.back().lhs = dnums.lhs_batch_dimensions(i);
58 mapping.batch_dims.back().rhs = dnums.rhs_batch_dimensions(i);
59 mapping.batch_dims.back().output = next_output_dim++;
60 }
61 for (int64_t i = 0; i < dnums.lhs_contracting_dimensions_size(); ++i) {
62 mapping.contracting_dims.emplace_back();
63 mapping.contracting_dims.back().lhs = dnums.lhs_contracting_dimensions(i);
64 mapping.contracting_dims.back().rhs = dnums.rhs_contracting_dimensions(i);
65 mapping.contracting_dims.back().output = -1;
66 }
67 for (int64_t i = 0; i < hlo->operand(0)->shape().rank(); ++i) {
68 if (absl::c_linear_search(dnums.lhs_batch_dimensions(), i) ||
69 absl::c_linear_search(dnums.lhs_contracting_dimensions(), i)) {
70 continue;
71 }
72 mapping.lhs_non_contracting_dims.emplace_back();
73 mapping.lhs_non_contracting_dims.back().lhs = i;
74 mapping.lhs_non_contracting_dims.back().rhs = -1;
75 mapping.lhs_non_contracting_dims.back().output = next_output_dim++;
76 }
77 for (int64_t i = 0; i < hlo->operand(1)->shape().rank(); ++i) {
78 if (absl::c_linear_search(dnums.rhs_batch_dimensions(), i) ||
79 absl::c_linear_search(dnums.rhs_contracting_dimensions(), i)) {
80 continue;
81 }
82 mapping.rhs_non_contracting_dims.emplace_back();
83 mapping.rhs_non_contracting_dims.back().lhs = -1;
84 mapping.rhs_non_contracting_dims.back().rhs = i;
85 mapping.rhs_non_contracting_dims.back().output = next_output_dim++;
86 }
87 auto create_sharded_dot =
88 [&](HloInstruction* l, HloInstruction* r, SpmdBuilder* b,
89 const Window& conv_window) -> StatusOr<HloInstruction*> {
90 TF_ASSIGN_OR_RETURN(
91 auto sharded_dot_shape,
92 ShapeInference::InferDotOpShape(
93 l->shape(), r->shape(), hlo->dot_dimension_numbers(),
94 /*preferred_element_type=*/hlo->shape().element_type()));
95 return b->AddInstruction(HloInstruction::CreateDot(
96 sharded_dot_shape, l, r, hlo->dot_dimension_numbers(),
97 hlo->precision_config()));
98 };
99 return HandleDotHelper(hlo, mapping, create_sharded_dot);
100 }
101
102 namespace {
103
104 enum class WindowedEinsumOperand { LHS, RHS };
105
106 struct WindowedEinsumConfig {
107 WindowedEinsumOperand windowed_op;
108 bool windowed_at_contracting_dims;
109 bool windowed_at_batch_dims;
110 bool operands_sharded_at_contracting_dims;
111 };
112
113 struct DotDimensionIndexMapping {
114 std::vector<int64_t> lhs_to_rhs_indices;
115 std::vector<int64_t> lhs_to_output_indices;
116 std::vector<int64_t> rhs_to_lhs_indices;
117 std::vector<int64_t> rhs_to_output_indices;
118 std::vector<int64_t> output_to_lhs_indices;
119 std::vector<int64_t> output_to_rhs_indices;
120 };
121
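// Adjusts dot dimension numbers after a new dimension has been inserted at
// position `reshaped_dim` in the lhs (or rhs) shape: every existing dimension
// index >= reshaped_dim is shifted up by one, and if `reshaped_dim` itself was
// a batch/contracting dimension it is re-added, so the dimension pair created
// by the reshape stays batched/contracted. For example, contracting dims
// {0, 2} with reshaped_dim = 2 become {0, 3, 2}.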
void UpdateDDNums(DotDimensionNumbers* new_ddnums, int64_t reshaped_dim,
123 bool lhs) {
124 auto update_dims =
125 [&reshaped_dim](tensorflow::protobuf::RepeatedField<int64_t>* dims) {
126 bool add_reshaped_dim = false;
127 if (absl::c_linear_search(*dims, reshaped_dim)) {
128 add_reshaped_dim = true;
129 }
130 for (int64_t i = 0; i < dims->size(); ++i) {
131 auto dim = dims->at(i);
132 if (reshaped_dim <= dim) {
133 dims->Set(i, dim + 1);
134 }
135 }
136 if (add_reshaped_dim) {
137 dims->Add(reshaped_dim);
138 }
139 };
140
141 if (lhs) {
142 update_dims(new_ddnums->mutable_lhs_contracting_dimensions());
143 update_dims(new_ddnums->mutable_lhs_batch_dimensions());
144 } else { // rhs
145 update_dims(new_ddnums->mutable_rhs_contracting_dimensions());
146 update_dims(new_ddnums->mutable_rhs_batch_dimensions());
147 }
148 }
149
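// Builds the window for the new convolution used by the bidirectional
// windowed einsum path below, where one operand carries an extra concat
// dimension (lhs_concat_dim/rhs_concat_dim). The spatial window dimension at
// the concat position is updated, and one extra window dimension is appended
// for the added dimension (with window reversal when the concat is on the rhs
// at non-contracting/non-batch dims).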
Window GenNewWindow(const HloInstruction* original_dot,
151 const HloInstruction* dot_lhs,
152 const HloInstruction* dot_rhs, int64_t lhs_concat_dim,
153 int64_t rhs_concat_dim, bool windowed_at_contracting_dims,
154 bool windowed_at_batch_dims) {
155 auto new_window = original_dot->window();
156 const ConvolutionDimensionNumbers& conv_dnums =
157 original_dot->convolution_dimension_numbers();
158 if (lhs_concat_dim != -1) {
159 for (int64_t i = 0; i < conv_dnums.input_spatial_dimensions_size(); ++i) {
160 if (conv_dnums.input_spatial_dimensions(i) == lhs_concat_dim) {
161 auto wd = new_window.mutable_dimensions(i);
162 auto lhs_size = dot_lhs->shape().dimensions(lhs_concat_dim + 1);
163 if (windowed_at_contracting_dims) {
164 wd->set_size(lhs_size);
165 }
166 if (windowed_at_batch_dims) {
167 wd->set_size(lhs_size);
168 wd->set_padding_low(0);
169 wd->set_padding_high(0);
170 wd->set_stride(std::max<int64_t>(1, lhs_size - 1));
171 wd->set_window_dilation(1);
172 wd->set_base_dilation(lhs_size);
173 wd->set_window_reversal(false);
174 }
175 }
176 }
177 }
178 if (rhs_concat_dim != -1) {
179 for (int64_t i = 0; i < conv_dnums.kernel_spatial_dimensions_size(); ++i) {
180 if (conv_dnums.kernel_spatial_dimensions(i) == rhs_concat_dim &&
181 !windowed_at_contracting_dims && !windowed_at_batch_dims &&
182 lhs_concat_dim == -1) {
183 auto wd = new_window.mutable_dimensions(i);
184 auto rhs_size = dot_rhs->shape().dimensions(rhs_concat_dim + 1);
185 wd->set_size(rhs_size);
186 wd->set_padding_low(rhs_size - 1);
187 wd->set_padding_high(rhs_size - 1);
188 }
189 }
190 }
191 // Add the extra dimension to window.
192 WindowDimension* new_dim = new_window.add_dimensions();
193 if (windowed_at_contracting_dims) {
194 new_dim->set_size(2);
195 new_dim->set_padding_low(0);
196 new_dim->set_padding_high(0);
197 new_dim->set_stride(1);
198 new_dim->set_window_dilation(1);
199 new_dim->set_base_dilation(1);
200 new_dim->set_window_reversal(false);
201 } else if (windowed_at_batch_dims) {
202 new_dim->set_size(2);
203 new_dim->set_padding_low(0);
204 new_dim->set_padding_high(0);
205 new_dim->set_stride(1); // std::max<int64_t>(1, 2 - 1)
206 new_dim->set_window_dilation(1);
207 new_dim->set_base_dilation(2);
208 new_dim->set_window_reversal(false);
209 } else {
210 if (lhs_concat_dim != -1) {
211 new_dim->set_size(1);
212 new_dim->set_padding_low(0);
213 new_dim->set_padding_high(0);
214 new_dim->set_stride(1);
215 new_dim->set_window_dilation(1);
216 new_dim->set_base_dilation(1);
217 new_dim->set_window_reversal(false);
218 }
219 if (rhs_concat_dim != -1) {
220 new_dim->set_size(2); // rhs_size
221 new_dim->set_padding_low(1); // rhs_size - 1
222 new_dim->set_padding_high(1); // rhs_size - 1
223 new_dim->set_stride(1);
224 new_dim->set_window_dilation(1);
225 new_dim->set_base_dilation(1);
226 new_dim->set_window_reversal(true);
227 }
228 }
229
230 VLOG(2) << "new_window: " << new_window.ShortDebugString();
231 return new_window;
232 }
233
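// Builds convolution dimension numbers for the new convolution above: existing
// batch/feature/spatial dimension indices at or above the inserted concat
// dimension are shifted up by one, and the inserted dimension (plus its
// counterpart on the other operand and in the output) is registered as a
// spatial dimension.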
ConvolutionDimensionNumbers GenNewConvDNums(
235 const HloInstruction* original_dot, const HloInstruction* dot_lhs,
236 const HloInstruction* dot_rhs, int64_t lhs_concat_dim,
237 int64_t rhs_concat_dim, bool windowed_at_contracting_dims,
238 bool windowed_at_batch_dims,
239 const std::vector<int64_t>& lhs_to_output_indices,
240 const std::vector<int64_t>& rhs_to_output_indices,
241 const Shape& new_dot_shape) {
242 // Generate the new conv dimension numbers.
243 const ConvolutionDimensionNumbers& dnums =
244 original_dot->convolution_dimension_numbers();
245 // Handle the LHS dimension numbers.
246 int64_t input_batch_dimension = dnums.input_batch_dimension();
247 int64_t input_feature_dimension = dnums.input_feature_dimension();
248 std::vector<int64_t> input_spatial_dimensions(
249 dnums.input_spatial_dimensions().begin(),
250 dnums.input_spatial_dimensions().end());
251 if (lhs_concat_dim != -1) {
252 if (lhs_concat_dim <= input_batch_dimension) {
253 input_batch_dimension++;
254 }
255 if (lhs_concat_dim <= input_feature_dimension) {
256 input_feature_dimension++;
257 }
258 for (int64_t i = 0; i < input_spatial_dimensions.size(); ++i) {
259 if (lhs_concat_dim <= input_spatial_dimensions[i]) {
260 input_spatial_dimensions[i]++;
261 }
262 }
263 input_spatial_dimensions.push_back(lhs_concat_dim);
264 }
265 if (rhs_concat_dim != -1 && !windowed_at_contracting_dims &&
266 !windowed_at_batch_dims) {
267 input_spatial_dimensions.push_back(dot_lhs->shape().dimensions_size() - 1);
268 }
269 // Handle the RHS dimension numbers.
270 int64_t kernel_input_feature_dimension =
271 dnums.kernel_input_feature_dimension();
272 int64_t kernel_output_feature_dimension =
273 dnums.kernel_output_feature_dimension();
274 std::vector<int64_t> kernel_spatial_dimensions(
275 dnums.kernel_spatial_dimensions().begin(),
276 dnums.kernel_spatial_dimensions().end());
277 if (rhs_concat_dim != -1) {
278 if (rhs_concat_dim <= kernel_input_feature_dimension) {
279 kernel_input_feature_dimension++;
280 }
281 if (rhs_concat_dim <= kernel_output_feature_dimension) {
282 kernel_output_feature_dimension++;
283 }
284 for (int64_t i = 0; i < kernel_spatial_dimensions.size(); ++i) {
285 if (rhs_concat_dim <= kernel_spatial_dimensions[i]) {
286 kernel_spatial_dimensions[i]++;
287 }
288 }
289 kernel_spatial_dimensions.push_back(rhs_concat_dim);
290 }
291 if (lhs_concat_dim != -1 && !windowed_at_contracting_dims &&
292 !windowed_at_batch_dims) {
293 kernel_spatial_dimensions.push_back(dot_rhs->shape().dimensions_size() - 1);
294 }
295 // Handle the Output dimension numbers.
296 int64_t output_batch_dimension = dnums.output_batch_dimension();
297 int64_t output_feature_dimension = dnums.output_feature_dimension();
298 std::vector<int64_t> output_spatial_dimensions(
299 dnums.output_spatial_dimensions().begin(),
300 dnums.output_spatial_dimensions().end());
301 if (!windowed_at_contracting_dims) {
302 auto output_slice_dim = lhs_concat_dim != -1
303 ? lhs_to_output_indices[lhs_concat_dim]
304 : rhs_to_output_indices[rhs_concat_dim];
305 if (output_slice_dim <= output_batch_dimension) {
306 output_batch_dimension++;
307 }
308 if (output_slice_dim <= output_feature_dimension) {
309 output_feature_dimension++;
310 }
311 for (int64_t i = 0; i < output_spatial_dimensions.size(); ++i) {
312 if (output_slice_dim <= output_spatial_dimensions[i]) {
313 output_spatial_dimensions[i]++;
314 }
315 }
316 output_spatial_dimensions.push_back(output_slice_dim);
317 } else {
318 output_spatial_dimensions.push_back(new_dot_shape.dimensions_size() - 1);
319 }
// Construct the new convolution dimension numbers.
321 ConvolutionDimensionNumbers new_dnums;
322 new_dnums.set_input_batch_dimension(input_batch_dimension);
323 new_dnums.set_input_feature_dimension(input_feature_dimension);
324 for (auto dim : input_spatial_dimensions) {
325 new_dnums.add_input_spatial_dimensions(dim);
326 }
327 new_dnums.set_kernel_input_feature_dimension(kernel_input_feature_dimension);
328 new_dnums.set_kernel_output_feature_dimension(
329 kernel_output_feature_dimension);
330 for (auto dim : kernel_spatial_dimensions) {
331 new_dnums.add_kernel_spatial_dimensions(dim);
332 }
333 new_dnums.set_output_batch_dimension(output_batch_dimension);
334 new_dnums.set_output_feature_dimension(output_feature_dimension);
335 for (auto dim : output_spatial_dimensions) {
336 new_dnums.add_output_spatial_dimensions(dim);
337 }
338
339 return new_dnums;
340 }
341
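// Builds index maps between lhs, rhs and output dimensions from
// `dims_mapping`. An entry is -1 when a dimension has no counterpart, e.g.
// contracting dimensions have no output dimension.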
DotDimensionIndexMapping ComputeDimensionIndexMapping(
343 const DotConvDimsMapping& dims_mapping, int64_t lhs_rank, int64_t rhs_rank,
344 int64_t output_rank) {
345 std::vector<int64_t> lhs_to_rhs_indices(lhs_rank, -1);
346 std::vector<int64_t> lhs_to_output_indices(lhs_rank, -1);
347 std::vector<int64_t> rhs_to_lhs_indices(rhs_rank, -1);
348 std::vector<int64_t> rhs_to_output_indices(rhs_rank, -1);
349 std::vector<int64_t> output_to_lhs_indices(output_rank, -1);
350 std::vector<int64_t> output_to_rhs_indices(output_rank, -1);
351 auto populate_indices_mapping =
352 [&](const DotConvDimsMapping::DimsMapping& mapping) {
353 if (mapping.lhs >= 0) {
354 lhs_to_rhs_indices[mapping.lhs] = mapping.rhs;
355 lhs_to_output_indices[mapping.lhs] = mapping.output;
356 }
357 if (mapping.rhs >= 0) {
358 rhs_to_lhs_indices[mapping.rhs] = mapping.lhs;
359 rhs_to_output_indices[mapping.rhs] = mapping.output;
360 }
361 if (mapping.output >= 0) {
362 output_to_lhs_indices[mapping.output] = mapping.lhs;
363 output_to_rhs_indices[mapping.output] = mapping.rhs;
364 }
365 };
366 for (const auto& mapping : dims_mapping.batch_dims) {
367 populate_indices_mapping(mapping);
368 }
369 for (const auto& mapping : dims_mapping.contracting_dims) {
370 populate_indices_mapping(mapping);
371 }
372 for (const auto& mapping : dims_mapping.lhs_non_contracting_dims) {
373 populate_indices_mapping(mapping);
374 }
375 for (const auto& mapping : dims_mapping.rhs_non_contracting_dims) {
376 populate_indices_mapping(mapping);
377 }
378 for (const auto& mapping : dims_mapping.conv_spatial_dims) {
379 populate_indices_mapping(mapping);
380 }
381 return DotDimensionIndexMapping{lhs_to_rhs_indices, lhs_to_output_indices,
382 rhs_to_lhs_indices, rhs_to_output_indices,
383 output_to_lhs_indices, output_to_rhs_indices};
384 }
385
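// Returns partition groups in which partitions differ only along
// `replication_dims` of the tile assignment. For example, a [2,2] tile
// assignment {{0,1},{2,3}} with replication_dims = {1} yields groups {0,1}
// and {2,3}.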
std::vector<std::vector<int64_t>> GetPartitionGroupsForReplication(
387 const HloSharding& sharding, absl::Span<const int64_t> replication_dims) {
388 int64_t group_size = 1;
389 for (int64_t i : replication_dims) {
390 group_size *= sharding.tile_assignment().dim(i);
391 }
392 std::vector<std::vector<int64_t>> partition_groups(
393 sharding.tile_assignment().num_elements() / group_size);
394 sharding.tile_assignment().Each(
395 [&](absl::Span<const int64_t> indices, int64_t partition) {
396 int64_t group_id = 0;
397 for (int64_t i = 0; i < indices.size(); ++i) {
398 if (!absl::c_linear_search(replication_dims, i)) {
399 group_id *= sharding.tile_assignment().dim(i);
400 group_id += indices[i];
401 }
402 }
403 partition_groups[group_id].push_back(partition);
404 });
405 return partition_groups;
406 }
407
408 // Returns true iff all of the following conditions are simultaneously true:
409 // 1) 'lhs/rhs_sharding' have different partition counts on a dimension in
410 // 'dims'.
411 // 2) 'lhs/rhs_sharding' BOTH have partitions on at least one dimension in
412 // 'dims'.
bool RequiresTransposeSharding(
414 const HloSharding& lhs_sharding, const HloSharding& rhs_sharding,
415 const std::vector<DotConvDimsMapping::DimsMapping>& dims) {
416 int64_t lhs_total_partitions = 1;
417 int64_t rhs_total_partitions = 1;
418 bool has_different_lhs_rhs_dim_sharding = false;
419 for (const auto& dim : dims) {
420 int64_t lhs_dim_partitions = lhs_sharding.tile_assignment().dim(dim.lhs);
421 lhs_total_partitions *= lhs_dim_partitions;
422
423 int64_t rhs_dim_partitions = rhs_sharding.tile_assignment().dim(dim.rhs);
424 rhs_total_partitions *= rhs_dim_partitions;
425
426 if (lhs_dim_partitions != rhs_dim_partitions) {
427 has_different_lhs_rhs_dim_sharding = true;
428 }
429 }
430 return lhs_total_partitions > 1 && rhs_total_partitions > 1 &&
431 has_different_lhs_rhs_dim_sharding;
432 }
433
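// Decides whether to use a windowed einsum loop and, if so, how: which operand
// is windowed, and whether the windowing happens at contracting dims, at batch
// dims, or with both operands sharded at contracting dims. Returns
// std::nullopt if num_partitions exceeds max_iterations or no suitable
// configuration is found.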
std::optional<WindowedEinsumConfig> GetWindowedEinsumConfiguration(
435 int64_t num_partitions, int64_t output_lhs_non_contracting_partitions,
436 int64_t output_rhs_non_contracting_partitions,
437 int64_t rhs_contracting_partitions, int64_t rhs_non_contracting_partitions,
438 int64_t rhs_batch_partitions, int64_t lhs_contracting_partitions,
439 int64_t lhs_non_contracting_partitions, int64_t lhs_batch_partitions,
440 int64_t rhs_shape_size, int64_t lhs_shape_size, int64_t output_shape_size,
441 const SpmdPartitionerOptions& options,
442 const std::optional<HloSharding>& output_sharding_transposed_to_match_lhs,
443 const std::optional<HloSharding>& output_sharding_transposed_to_match_rhs,
444 const std::optional<HloSharding>& lhs_sharding_transposed_to_match_rhs,
445 const std::optional<HloSharding>& rhs_sharding_transposed_to_match_lhs,
446 const HloSharding& lhs_sharding, const HloSharding& rhs_sharding,
447 const Window& conv_window, const DotConvDimsMapping& dims_mapping,
448 int64_t max_iterations = INT64_MAX,
449 const HloInstruction* original_hlo = nullptr,
450 PartitionedHlo* partitioned_lhs = nullptr,
451 PartitionedHlo* partitioned_rhs = nullptr,
452 const std::function<StatusOr<HloInstruction*>(
453 HloInstruction*, HloInstruction*, SpmdBuilder*,
454 const Window& conv_window)>& create_sharded_dot = {},
455 SpmdBuilder* b = nullptr, HloModule* module = nullptr,
456 SpmdPartitioningVisitor* visitor = nullptr) {
457 if (num_partitions > max_iterations) {
458 return std::nullopt;
459 }
460
461 const HloInstruction* lhs = nullptr;
462 const HloInstruction* rhs = nullptr;
463 if (original_hlo) {
464 lhs = original_hlo->operand(0);
465 rhs = original_hlo->operand(1);
466 }
467
// Determine whether any of the other users have the same sharding needs, which
// would allow reuse of the resharding for the operand with original_hlo.
470 auto check_users_sharding = [original_hlo](
const HloInstruction* to_loop_over) {
472 if (to_loop_over->users().size() <= 1) {
473 return true;
474 }
475 constexpr int kAggressiveness = 3;
476 std::optional<HloSharding> original_ideal_sharding =
477 ShardingPropagation::GetShardingFromUser(*to_loop_over, *original_hlo,
478 kAggressiveness,
479 /*is_spmd=*/true);
// Default to performing collective matmul if GetShardingFromUser() couldn't
// determine the sharding.
482 if (!original_ideal_sharding) {
483 return true;
484 }
485 for (const HloInstruction* user : to_loop_over->users()) {
486 if (user == original_hlo) {
487 continue;
488 }
489 std::optional<HloSharding> from_user =
490 ShardingPropagation::GetShardingFromUser(*to_loop_over, *user,
491 kAggressiveness,
492 /*is_spmd=*/true);
// Couldn't determine the sharding. Skip to the next one and assume it
// wouldn't share the resharding.
495 if (!from_user) {
496 continue;
497 }
// This user doesn't require resharding, so even if it has a different
// sharding than original_hlo it's OK to do collective matmul.
500 if (*from_user == to_loop_over->sharding()) {
501 continue;
502 }
// This user needs the same sharding as original_hlo, so the resharding
// would be shared. Do not do collective matmul.
505 if (*original_ideal_sharding == *from_user) {
506 return false;
507 }
508 }
509 return true;
510 };
511
// Disable windowed einsum when the overheads may outweigh the benefits.
// Specifically, when max(computation time, communication time after
// decomposition) + the extra prologue/epilogue collective-permute time is
// longer than the sum of the computation time and the original communication
// time (which can use more communication links). This is checked under the
// premise that the communication/computation is large enough; for the tiny
// communication/computation generated by unit tests, we always allow windowed
// einsum so those tests stay meaningful.
auto disable_windowed_einsum = [&](bool lhs_needs_ag, bool rhs_needs_ag) {
521 if (visitor == nullptr) {
522 return false;
523 }
524
525 double computation_time_in_ms = 0.0;
526 double communication_time_in_ms = 0.0;
527 HloInstruction* dot;
528 HloInstruction* collective;
529 if (lhs_needs_ag || rhs_needs_ag) {
530 CHECK(!lhs_needs_ag || !rhs_needs_ag);
531 auto new_lhs = lhs_needs_ag
532 ? PartitionedHlo(partitioned_lhs->hlo(),
533 partitioned_lhs->base_shape(),
534 partitioned_lhs->state())
535 .Reshard(HloSharding::Replicate())
536 : *partitioned_lhs;
537 auto new_rhs = rhs_needs_ag
538 ? PartitionedHlo(partitioned_rhs->hlo(),
539 partitioned_rhs->base_shape(),
540 partitioned_rhs->state())
541 .Reshard(HloSharding::Replicate())
542 : *partitioned_rhs;
543 dot = create_sharded_dot(new_lhs.hlo(), new_rhs.hlo(), b, conv_window)
544 .ValueOrDie();
545 computation_time_in_ms = visitor->GetComputationTimeInMilliSec(dot);
546
547 collective = lhs_needs_ag ? new_lhs.hlo() : new_rhs.hlo();
548 while (collective->opcode() != HloOpcode::kAllGather &&
549 collective->opcode() != HloOpcode::kAllReduce &&
550 collective->operand_count() > 0 &&
551 collective != (lhs_needs_ag ? partitioned_lhs->hlo()
552 : partitioned_rhs->hlo())) {
553 collective = collective->mutable_operand(0);
554 }
555 if (collective->opcode() == HloOpcode::kAllGather ||
556 collective->opcode() == HloOpcode::kAllReduce) {
557 communication_time_in_ms = visitor->GetCommunicationTimeInMilliSec(
558 ShapeUtil::ByteSizeOf(collective->shape()),
559 collective->replica_groups());
560 }
561 } else {
562 auto new_lhs =
563 PartitionedHlo(partitioned_lhs->hlo(), partitioned_lhs->base_shape(),
564 partitioned_lhs->state());
565 auto new_rhs =
566 PartitionedHlo(partitioned_rhs->hlo(), partitioned_rhs->base_shape(),
567 partitioned_rhs->state());
568
569 // Check if contracting dimension sharding requires lhs/rhs resharding.
570 if (RequiresTransposeSharding(lhs->sharding(), rhs->sharding(),
571 dims_mapping.contracting_dims) &&
572 rhs_sharding_transposed_to_match_lhs.has_value() &&
573 lhs_sharding_transposed_to_match_rhs.has_value()) {
574 if (ShapeSizeInBytes(lhs->shape()) < ShapeSizeInBytes(rhs->shape())) {
575 new_lhs = new_lhs.Reshard(*rhs_sharding_transposed_to_match_lhs);
576 } else {
577 new_rhs = new_rhs.Reshard(*lhs_sharding_transposed_to_match_rhs);
578 }
579 }
580 new_lhs = new_lhs.PadWithZero();
581 new_rhs = new_rhs.PadWithZero();
582
583 dot = create_sharded_dot(new_lhs.hlo(), new_rhs.hlo(), b, conv_window)
584 .ValueOrDie();
585 computation_time_in_ms = visitor->GetComputationTimeInMilliSec(dot);
586
587 std::vector<int64_t> lhs_contracting_dims;
588 lhs_contracting_dims.reserve(new_lhs.base_shape().rank());
589 for (const auto& cd : dims_mapping.contracting_dims) {
590 lhs_contracting_dims.push_back(cd.lhs);
591 }
592 collective = new_lhs.state().partitioner->AllReduceAlongShardingDims(
593 b, dot, new_lhs.sharding(), new_lhs.state().next_channel_id,
594 lhs_contracting_dims, new_lhs.state().collective_ops_creator,
595 MakeBinaryAdd(dot->shape().element_type(), module));
596 communication_time_in_ms = visitor->GetCommunicationTimeInMilliSec(
597 ShapeUtil::ByteSizeOf(dot->shape()), collective->replica_groups());
598 }
599
600 VLOG(2) << "collective: " << collective->ToString() << "\n"
601 << "dot: " << dot->ToString() << "\n"
602 << "num_partitions: " << num_partitions << "\n"
603 << "computation_time_in_ms: " << computation_time_in_ms
604 << " communication_time_in_ms: " << communication_time_in_ms;
605 double extra_collective_permute_time = 0.0;
606 if (communication_time_in_ms != 0.0) {
607 extra_collective_permute_time =
608 communication_time_in_ms *
609 visitor->GetCommunicationMultiplier(collective->replica_groups()) *
610 2 / num_partitions;
611 }
612 if (communication_time_in_ms > 1e-5 &&
613 (std::max(
614 computation_time_in_ms,
615 communication_time_in_ms * visitor->GetCommunicationMultiplier(
616 collective->replica_groups())) +
617 extra_collective_permute_time) >=
618 (computation_time_in_ms + communication_time_in_ms)) {
619 return true;
620 } else {
621 return false;
622 }
623 };
624
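// Windowed einsum with the RHS windowed: the output already matches the LHS
// sharding on its non-contracting dims and the RHS is large enough to be worth
// streaming in slices rather than materializing in full.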
625 if (output_lhs_non_contracting_partitions == num_partitions &&
626 output_sharding_transposed_to_match_lhs == lhs_sharding &&
627 rhs_shape_size >=
628 options.threshold_for_windowed_einsum_mib * 1024 * 1024 &&
629 (!rhs || check_users_sharding(rhs)) &&
630 !disable_windowed_einsum(/*lhs_needs_ag=*/false, /*rhs_needs_ag=*/true)) {
631 if (rhs_contracting_partitions == num_partitions) {
632 return WindowedEinsumConfig{
633 /*windowed_op=*/WindowedEinsumOperand::RHS,
634 /*windowed_at_contracting_dims*/ true,
635 /*windowed_at_batch_dims=*/false,
636 /*operands_sharded_at_contracting_dims=*/false};
637 }
638 if (rhs_non_contracting_partitions == num_partitions) {
639 return WindowedEinsumConfig{
640 /*windowed_op=*/WindowedEinsumOperand::RHS,
641 /*windowed_at_contracting_dims*/ false,
642 /*windowed_at_batch_dims=*/false,
643 /*operands_sharded_at_contracting_dims=*/false};
644 }
645 if (rhs_batch_partitions == num_partitions) {
646 return WindowedEinsumConfig{
647 /*windowed_op=*/WindowedEinsumOperand::RHS,
648 /*windowed_at_contracting_dims*/ false,
649 /*windowed_at_batch_dims=*/true,
650 /*operands_sharded_at_contracting_dims=*/false};
651 }
652 }
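// The symmetric case with the LHS windowed: the output matches the RHS
// sharding on its non-contracting dims and the LHS is large.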
653 if (output_rhs_non_contracting_partitions == num_partitions &&
654 output_sharding_transposed_to_match_rhs == rhs_sharding &&
655 lhs_shape_size >=
656 options.threshold_for_windowed_einsum_mib * 1024 * 1024 &&
657 (!lhs || check_users_sharding(lhs)) &&
658 !disable_windowed_einsum(/*lhs_needs_ag=*/true, /*rhs_needs_ag=*/false)) {
659 if (lhs_contracting_partitions == num_partitions) {
660 return WindowedEinsumConfig{
661 /*windowed_op=*/WindowedEinsumOperand::LHS,
662 /*windowed_at_contracting_dims*/ true,
663 /*windowed_at_batch_dims=*/false,
664 /*operands_sharded_at_contracting_dims=*/false};
665 }
666 if (lhs_non_contracting_partitions == num_partitions) {
667 return WindowedEinsumConfig{
668 /*windowed_op=*/WindowedEinsumOperand::LHS,
669 /*windowed_at_contracting_dims*/ false,
670 /*windowed_at_batch_dims=*/false,
671 /*operands_sharded_at_contracting_dims=*/false};
672 }
673 if (lhs_batch_partitions == num_partitions) {
674 return WindowedEinsumConfig{
675 /*windowed_op=*/WindowedEinsumOperand::LHS,
676 /*windowed_at_contracting_dims*/ false,
677 /*windowed_at_batch_dims=*/true,
678 /*operands_sharded_at_contracting_dims=*/false};
679 }
680 }
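// Both operands are sharded at the contracting dims while the output is
// sharded along one side's non-contracting dims; a windowed loop here avoids
// all-reducing the full output.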
681 if (lhs_contracting_partitions == rhs_contracting_partitions &&
682 lhs_contracting_partitions == num_partitions &&
683 (output_lhs_non_contracting_partitions == num_partitions ||
684 output_rhs_non_contracting_partitions == num_partitions) &&
685 output_shape_size >=
686 options.threshold_for_windowed_einsum_mib * 1024 * 1024 &&
687 !disable_windowed_einsum(/*lhs_needs_ag=*/false,
688 /*rhs_needs_ag=*/false)) {
689 if (output_lhs_non_contracting_partitions == num_partitions) {
690 return WindowedEinsumConfig{
691 /*windowed_op=*/WindowedEinsumOperand::RHS,
692 /*windowed_at_contracting_dims*/ false,
693 /*windowed_at_batch_dims=*/false,
694 /*operands_sharded_at_contracting_dims=*/true};
695 }
696 if (output_rhs_non_contracting_partitions == num_partitions) {
697 return WindowedEinsumConfig{
698 /*windowed_op=*/WindowedEinsumOperand::LHS,
699 /*windowed_at_contracting_dims*/ false,
700 /*windowed_at_batch_dims=*/false,
701 /*operands_sharded_at_contracting_dims=*/true};
702 }
703 }
704 return std::nullopt;
705 }
706
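// Reconstructs replica groups from a windowed einsum while loop by following
// the cycles of the first collective-permute found in the loop body: each
// source->target chain forms one (sorted) group.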
std::vector<ReplicaGroup> GetLoopReplicaGroups(HloInstruction* while_loop) {
708 std::vector<ReplicaGroup> groups;
709 for (auto inst : while_loop->while_body()->instructions()) {
710 if (inst->opcode() == HloOpcode::kCollectivePermute) {
711 std::vector<std::pair<int64_t, int64_t>> st_pairs =
712 inst->source_target_pairs();
713 std::vector<int64_t> source_index(st_pairs.size());
714 for (int64_t i = 0; i < st_pairs.size(); ++i) {
715 source_index[st_pairs[i].first] = i;
716 }
717
718 absl::flat_hash_set<int64_t> visited;
719 for (int64_t i = 0; i < st_pairs.size(); ++i) {
720 if (visited.contains(st_pairs[i].first)) {
721 continue;
722 }
723 std::vector<int64_t> replica_group;
724 int64_t source = st_pairs[i].first;
725 int64_t target = st_pairs[i].second;
726 replica_group.push_back(source);
727 replica_group.push_back(target);
728 visited.insert(source);
729 visited.insert(target);
730 while (target != source) {
731 target = st_pairs[source_index[target]].second;
732 if (target != source) {
733 replica_group.push_back(target);
734 visited.insert(target);
735 }
736 }
737 absl::c_sort(replica_group);
738 groups.emplace_back();
739 for (auto id : replica_group) {
740 groups.back().add_replica_ids(id);
741 }
742 }
743
744 VLOG(3) << "while loop: " << while_loop->name()
745 << ", replica groups: " << ReplicaGroupsToString(groups);
746 break;
747 }
748 }
749 return groups;
750 }
751
// We use a recursive approach where sets of matching dimensions are recognized
// one at a time. The base shapes and shardings can be changed during the
// recursion as we group devices together, so refer to the passed-in shapes and
// shardings for the inputs and output, and do not use shape inference.
756
StatusOr<HloInstruction*> PartitionBaseCase(
758 PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
759 const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
760 int64_t num_partitions,
761 const std::function<StatusOr<HloInstruction*>(
762 HloInstruction*, HloInstruction*, SpmdBuilder*,
763 const Window& conv_window)>& create_sharded_dot,
764 const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
765 int64_t lhs_batch_partitions, int64_t rhs_batch_partitions,
766 int64_t output_batch_partitions, int64_t lhs_contracting_partitions,
767 int64_t rhs_contracting_partitions, int64_t lhs_non_contracting_partitions,
768 int64_t rhs_non_contracting_partitions,
769 int64_t output_lhs_non_contracting_partitions,
770 int64_t output_rhs_non_contracting_partitions,
771 const SpmdPartitionerOptions& options, SpmdBuilder* b,
772 std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
773 windowed_dot_general_loops,
774 bool may_reshard_without_detecting_match,
775 SpmdPartitioningVisitor* visitor) {
776 const HloSharding& lhs_sharding = lhs.sharding();
777 const HloSharding& rhs_sharding = rhs.sharding();
778 if (lhs_sharding.ReplicateOnLastTileDim() ||
779 rhs_sharding.ReplicateOnLastTileDim() ||
780 output_sharding.ReplicateOnLastTileDim()) {
781 return nullptr;
782 }
783 DotDimensionIndexMapping indices_map = ComputeDimensionIndexMapping(
784 dims_mapping, lhs.base_shape().rank(), rhs.base_shape().rank(),
785 output_base_shape.rank());
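// Pre-compute shardings transposed between the lhs/rhs/output dimension
// orders. Each may be std::nullopt when the transpose is not representable
// (e.g. a sharded dimension has no counterpart).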
786 auto lhs_sharding_transposed_to_match_rhs =
787 hlo_sharding_util::TransposeShardingWithCollapsedDims(
788 lhs_sharding, indices_map.lhs_to_rhs_indices,
789 indices_map.rhs_to_lhs_indices);
790 auto rhs_sharding_transposed_to_match_lhs =
791 hlo_sharding_util::TransposeShardingWithCollapsedDims(
792 rhs_sharding, indices_map.rhs_to_lhs_indices,
793 indices_map.lhs_to_rhs_indices);
794 auto lhs_sharding_transposed_to_match_output =
795 hlo_sharding_util::TransposeShardingWithCollapsedDims(
796 lhs_sharding, indices_map.lhs_to_output_indices,
797 indices_map.output_to_lhs_indices);
798 auto rhs_sharding_transposed_to_match_output =
799 hlo_sharding_util::TransposeShardingWithCollapsedDims(
800 rhs_sharding, indices_map.rhs_to_output_indices,
801 indices_map.output_to_rhs_indices);
802 auto output_sharding_transposed_to_match_lhs =
803 hlo_sharding_util::TransposeShardingWithCollapsedDims(
804 output_sharding, indices_map.output_to_lhs_indices,
805 indices_map.lhs_to_output_indices);
806 auto output_sharding_transposed_to_match_rhs =
807 hlo_sharding_util::TransposeShardingWithCollapsedDims(
808 output_sharding, indices_map.output_to_rhs_indices,
809 indices_map.rhs_to_output_indices);
810
811 // LHS and RHS are partitioned the same way and only partitioned in batch
812 // dimensions.
813 if (lhs_batch_partitions == rhs_batch_partitions &&
814 rhs_batch_partitions == num_partitions &&
815 lhs_sharding_transposed_to_match_rhs == rhs_sharding) {
816 TF_ASSIGN_OR_RETURN(
817 auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
818 dot->set_sharding(*lhs_sharding_transposed_to_match_output);
819 return PartitionedHlo(dot, output_base_shape, lhs.state())
820 .Reshard(output_sharding)
821 .hlo();
822 }
823
// Try to emit a batch-partitioned einsum with one operand resharded. Returns
// the partitioned HLO, or nullptr if the attempt fails. If
// may_reshard_with_allreduce is false, the reshard must be doable with
// all-to-all/collective-permute, otherwise this attempt fails.
828 auto try_emit_output_batch_partitioned_einsum_with_reshard =
829 [&](bool may_reshard_with_allreduce) -> StatusOr<HloInstruction*> {
830 // LHS and output are batch partitioned in the same way.
831 if (lhs_batch_partitions == num_partitions &&
832 output_batch_partitions == num_partitions &&
833 lhs_sharding_transposed_to_match_output == output_sharding) {
834 if (!may_reshard_with_allreduce &&
835 !CanReshardWithCollectivePermute(
836 rhs.sharding(), *lhs_sharding_transposed_to_match_rhs) &&
837 !GetReshardAllToAllSourceTargetDims(
838 rhs.sharding(), *lhs_sharding_transposed_to_match_rhs)) {
839 return nullptr;
840 }
841 auto resharded_rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs);
842 TF_ASSIGN_OR_RETURN(
843 auto dot,
844 create_sharded_dot(lhs.hlo(), resharded_rhs.hlo(), b, conv_window));
845 return dot;
846 }
847 // RHS and output are batch partitioned in the same way.
848 if (rhs_batch_partitions == num_partitions &&
849 output_batch_partitions == num_partitions &&
850 rhs_sharding_transposed_to_match_output == output_sharding) {
851 if (!may_reshard_with_allreduce &&
852 !CanReshardWithCollectivePermute(
853 lhs.sharding(), *rhs_sharding_transposed_to_match_lhs) &&
854 !GetReshardAllToAllSourceTargetDims(
855 lhs.sharding(), *rhs_sharding_transposed_to_match_lhs)) {
856 return nullptr;
857 }
858 auto resharded_lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs);
859 TF_ASSIGN_OR_RETURN(
860 auto dot,
861 create_sharded_dot(resharded_lhs.hlo(), rhs.hlo(), b, conv_window));
862 return dot;
863 }
864 return nullptr;
865 };
866
867 {
868 // Try batch-parallel by resharding one operand, and not using all-reduce.
869 TF_ASSIGN_OR_RETURN(
870 HloInstruction * partitioned_dot,
871 try_emit_output_batch_partitioned_einsum_with_reshard(false));
872 if (partitioned_dot) {
873 return partitioned_dot;
874 }
875 }
876
877 // Try to emit windowed DotGeneral when one operand is partitioned in the same
878 // way as the output along non-contracting dimensions, but the other operand
879 // is tiled in other dimensions. Or both operands are partitioned in the same
880 // way along contracting dimensions, but the output is partitioned along
881 // non-contracting dimensions.
882 auto emit_windowed_dot_general =
883 [&](const WindowedEinsumConfig& einsum_config)
884 -> StatusOr<HloInstruction*> {
885 CHECK(!einsum_config.windowed_at_batch_dims ||
886 !einsum_config.windowed_at_contracting_dims);
887 const bool windowed_at_batch_dims = einsum_config.windowed_at_batch_dims;
888 const bool windowed_at_contracting_dims =
889 einsum_config.windowed_at_contracting_dims;
890 const bool operands_sharded_at_contracting_dims =
891 einsum_config.operands_sharded_at_contracting_dims;
892 auto unpadded_result_buffer_shape =
893 MakePartitionedShape(output_base_shape, output_sharding);
894 auto padded_result_buffer_shape = unpadded_result_buffer_shape;
895 const bool windowed_op_is_lhs =
896 einsum_config.windowed_op == WindowedEinsumOperand::LHS;
// For windowing at batch/non-contracting dims, we produce the result one
// partition at a time, so we need to pad the shape in case of uneven
// partitioning in order to keep the dynamic-update-slice in bounds.
900 if (!windowed_at_contracting_dims &&
901 !operands_sharded_at_contracting_dims) {
902 padded_result_buffer_shape = GetPaddedShapeForUnevenPartitioning(
903 padded_result_buffer_shape,
904 windowed_op_is_lhs ? *lhs_sharding_transposed_to_match_output
905 : *rhs_sharding_transposed_to_match_output);
906 }
907 // Mask the padding area of the windowed operand with zero if there is
908 // uneven partitioning.
909 if (windowed_at_contracting_dims) {
910 auto& to_mask = windowed_op_is_lhs ? lhs : rhs;
911 to_mask = to_mask.PadWithZero();
912 }
913 if (operands_sharded_at_contracting_dims) {
914 lhs = lhs.PadWithZero();
915 rhs = rhs.PadWithZero();
916 }
917
918 // Check if contracting dimension sharding requires lhs/rhs resharding.
919 if (RequiresTransposeSharding(lhs.hlo()->sharding(), rhs.hlo()->sharding(),
920 dims_mapping.contracting_dims) &&
921 rhs_sharding_transposed_to_match_lhs.has_value() &&
922 lhs_sharding_transposed_to_match_rhs.has_value()) {
923 if (ShapeSizeInBytes(lhs.hlo()->shape()) <
924 ShapeSizeInBytes(rhs.hlo()->shape())) {
925 lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithZero();
926 } else {
927 rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithZero();
928 }
929 }
930
931 // Get slice sharding, sharding dim, and lhs/rhs concat dim.
932 const HloSharding* slice_sharding;
933 if (operands_sharded_at_contracting_dims) {
934 slice_sharding = windowed_op_is_lhs
935 ? &*output_sharding_transposed_to_match_rhs
936 : &*output_sharding_transposed_to_match_lhs;
937 } else if (windowed_at_contracting_dims || windowed_at_batch_dims) {
938 slice_sharding = windowed_op_is_lhs
939 ? &*lhs_sharding_transposed_to_match_rhs
940 : &*rhs_sharding_transposed_to_match_lhs;
941 } else {
942 slice_sharding = windowed_op_is_lhs
943 ? &*lhs_sharding_transposed_to_match_output
944 : &*rhs_sharding_transposed_to_match_output;
945 }
946 CHECK_EQ(Product(slice_sharding->tile_assignment().dimensions()),
947 num_partitions);
948 int64_t slice_sharding_dim = -1;
949 for (int64_t i = 0; i < slice_sharding->tile_assignment().num_dimensions();
950 ++i) {
951 if (slice_sharding->tile_assignment().dim(i) > 1) {
952 slice_sharding_dim = i;
953 break;
954 }
955 }
956 int64_t lhs_concat_dim = -1;
957 int64_t rhs_concat_dim = -1;
958 if (operands_sharded_at_contracting_dims) {
959 if (windowed_op_is_lhs) {
960 rhs_concat_dim = slice_sharding_dim;
961 } else {
962 lhs_concat_dim = slice_sharding_dim;
963 }
964 } else if (windowed_at_contracting_dims || windowed_at_batch_dims) {
965 lhs_concat_dim = windowed_op_is_lhs
966 ? indices_map.rhs_to_lhs_indices[slice_sharding_dim]
967 : slice_sharding_dim;
968 rhs_concat_dim = windowed_op_is_lhs
969 ? slice_sharding_dim
970 : indices_map.lhs_to_rhs_indices[slice_sharding_dim];
971 } else {
972 if (windowed_op_is_lhs) {
973 lhs_concat_dim = indices_map.output_to_lhs_indices[slice_sharding_dim];
974 } else {
975 rhs_concat_dim = indices_map.output_to_rhs_indices[slice_sharding_dim];
976 }
977 }
978
979 auto lhs_hlo = lhs.hlo();
980 auto rhs_hlo = rhs.hlo();
981 // Reshape lhs and rhs before the loop for bidirectional communication case.
982 if (options.bidirectional_windowed_einsum && num_partitions % 4 == 0) {
983 if (lhs_concat_dim != -1 && windowed_op_is_lhs &&
984 !operands_sharded_at_contracting_dims) {
985 std::vector<int64_t> reshaped_dims(
986 lhs_hlo->shape().dimensions().begin(),
987 lhs_hlo->shape().dimensions().end());
988 reshaped_dims.insert(reshaped_dims.begin() + lhs_concat_dim, 1);
989 lhs_hlo = b->AddInstruction(HloInstruction::CreateReshape(
990 ShapeUtil::MakeShape(lhs_hlo->shape().element_type(),
991 reshaped_dims),
992 lhs_hlo));
993 }
994 if (rhs_concat_dim != -1 && !windowed_op_is_lhs &&
995 !operands_sharded_at_contracting_dims) {
996 std::vector<int64_t> reshaped_dims(
997 rhs_hlo->shape().dimensions().begin(),
998 rhs_hlo->shape().dimensions().end());
999 reshaped_dims.insert(reshaped_dims.begin() + rhs_concat_dim, 1);
1000 rhs_hlo = b->AddInstruction(HloInstruction::CreateReshape(
1001 ShapeUtil::MakeShape(rhs_hlo->shape().element_type(),
1002 reshaped_dims),
1003 rhs_hlo));
1004 }
1005 }
1006
1007 auto result_buffer = CreateZero(padded_result_buffer_shape, b);
1008 auto extra_buffer =
1009 (!(options.bidirectional_windowed_einsum && num_partitions % 4 == 0) ||
1010 operands_sharded_at_contracting_dims)
1011 ? CreateZero(padded_result_buffer_shape, b)
1012 : windowed_op_is_lhs ? lhs_hlo
1013 : rhs_hlo;
1014
1015 if (options.bidirectional_windowed_einsum && num_partitions % 4 == 0 &&
1016 !operands_sharded_at_contracting_dims) {
1017 std::vector<std::pair<int64_t, int64_t>> pre_sd_pairs(num_partitions);
1018 for (int64_t source = 0; source < num_partitions; ++source) {
1019 // 0 -> 1, 1 -> 2, 2 -> 3, ...
1020 pre_sd_pairs[source] = {source, (source + 1) % num_partitions};
1021 }
1022 extra_buffer =
1023 lhs.state()
1024 .collective_ops_creator.create_cross_partition_collective_permute(
1025 b, extra_buffer, pre_sd_pairs,
1026 (*lhs.state().next_channel_id)++);
1027 }
1028
1029 auto iteration = b->AddInstruction(
1030 HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32_t>(0)));
1031
1032 // Create a while loop that computes one window per iteration. During each
1033 // iteration, each partition sends its input window to its neighbor using
1034 // collective-permute for the next iteration.
1035 SpmdBuilder body_b("windowed_dot_general_body", original_hlo);
1036
1037 // Generate partial results used by bidirectional algorithm.
1038 auto get_partial_bid_results =
1039 [&](HloInstruction* l, HloInstruction* r, HloInstruction* o,
1040 HloInstruction* extra_inout, HloInstruction* cw_cp_output,
1041 HloInstruction* i) -> StatusOr<std::vector<HloInstruction*>> {
1042 auto partition_id =
1043 lhs.state().collective_ops_creator.create_partition_id(&body_b);
1044 auto partition_count =
1045 body_b.AddInstruction(HloInstruction::CreateConstant(
1046 LiteralUtil::CreateR0<uint32_t>(num_partitions)));
1047 auto ccw_data_partition_id =
1048 body_b.AddInstruction(HloInstruction::CreateBinary(
1049 i->shape(), HloOpcode::kAdd, i, partition_id));
1050 auto cw_data_partition_id =
1051 body_b.AddInstruction(HloInstruction::CreateBinary(
1052 i->shape(), HloOpcode::kAdd, partition_count, partition_id));
1053 if (operands_sharded_at_contracting_dims) {
1054 ccw_data_partition_id =
1055 body_b.AddInstruction(HloInstruction::CreateBinary(
1056 i->shape(), HloOpcode::kAdd, ccw_data_partition_id,
1057 body_b.AddInstruction(HloInstruction::CreateConstant(
1058 LiteralUtil::CreateR0<uint32_t>(num_partitions / 2 + 1)))));
1059 cw_data_partition_id =
1060 body_b.AddInstruction(HloInstruction::CreateBinary(
1061 i->shape(), HloOpcode::kSubtract, cw_data_partition_id,
1062 body_b.AddInstruction(HloInstruction::CreateConstant(
1063 LiteralUtil::CreateR0<uint32_t>(num_partitions / 2)))));
1064 } else {
1065 cw_data_partition_id =
1066 body_b.AddInstruction(HloInstruction::CreateBinary(
1067 i->shape(), HloOpcode::kSubtract, cw_data_partition_id,
1068 CreateOne(cw_data_partition_id->shape(), &body_b)));
1069 }
1070 ccw_data_partition_id = body_b.AddInstruction(
1071 HloInstruction::CreateBinary(i->shape(), HloOpcode::kRemainder,
1072 ccw_data_partition_id, partition_count));
1073 cw_data_partition_id = body_b.AddInstruction(HloInstruction::CreateBinary(
1074 i->shape(), HloOpcode::kSubtract, cw_data_partition_id, i));
1075 cw_data_partition_id = body_b.AddInstruction(
1076 HloInstruction::CreateBinary(i->shape(), HloOpcode::kRemainder,
1077 cw_data_partition_id, partition_count));
1078
1079 DotDimensionNumbers new_ddnums;
1080 if (original_hlo->opcode() == HloOpcode::kDot) {
1081 new_ddnums = original_hlo->dot_dimension_numbers();
1082 }
1083
1084 auto dot_lhs = l;
1085 auto dot_rhs = r;
1086 auto original_dot_lhs = l;
1087 auto original_dot_rhs = r;
// Recover the original lhs and rhs; these will not be used in the real computation.
1089 if (lhs_concat_dim != -1 && windowed_op_is_lhs) {
1090 std::vector<int64_t> reshaped_dims(
1091 original_dot_lhs->shape().dimensions().begin(),
1092 original_dot_lhs->shape().dimensions().end());
1093 reshaped_dims.erase(reshaped_dims.begin() + lhs_concat_dim);
1094 original_dot_lhs = body_b.AddInstruction(HloInstruction::CreateReshape(
1095 ShapeUtil::MakeShape(original_dot_lhs->shape().element_type(),
1096 reshaped_dims),
1097 original_dot_lhs));
1098 }
1099 if (rhs_concat_dim != -1 && !windowed_op_is_lhs) {
1100 std::vector<int64_t> reshaped_dims(
1101 original_dot_rhs->shape().dimensions().begin(),
1102 original_dot_rhs->shape().dimensions().end());
1103 reshaped_dims.erase(reshaped_dims.begin() + rhs_concat_dim);
1104 original_dot_rhs = body_b.AddInstruction(HloInstruction::CreateReshape(
1105 ShapeUtil::MakeShape(original_dot_rhs->shape().element_type(),
1106 reshaped_dims),
1107 original_dot_rhs));
1108 }
1109
1110 if (windowed_at_contracting_dims || windowed_at_batch_dims ||
1111 operands_sharded_at_contracting_dims) {
1112 // Slice the matching operand according to the partitioned dimensions
1113 // on the windowed operand or the output.
1114 auto slice_operand = !windowed_op_is_lhs ? l : r;
1115
1116 // Pad the sharding dim first (then the concat dim) for correctness.
1117 auto sharding_dim_size =
1118 slice_operand->shape().dimensions(slice_sharding_dim);
1119 if (sharding_dim_size % num_partitions != 0) {
1120 slice_operand = PadBaseShapeBeforeUnevenTiledSharding(
1121 slice_operand, *slice_sharding, &body_b);
1122 }
1123
1124 // We do this by treating the matching operand as replicated, and
1125 // resharding it to match the windowed operand or the output.
1126 auto gen_slice = [&](HloInstruction* data_partition_id,
1127 bool ccw) -> HloInstruction* {
1128 std::vector<int64_t> new_dims;
1129 const int64_t dimensions_size =
1130 slice_operand->shape().dimensions_size();
1131 new_dims.reserve(dimensions_size + 1);
1132 for (int64_t i = 0; i < dimensions_size; ++i) {
1133 if (i == slice_sharding_dim) {
1134 new_dims.push_back(1);
1135 }
1136 new_dims.push_back(slice_operand->shape().dimensions(i));
1137 }
1138 auto reshaped_slice_operand =
1139 body_b.AddInstruction(HloInstruction::CreateReshape(
1140 ShapeUtil::MakeShape(slice_operand->shape().element_type(),
1141 new_dims),
1142 slice_operand));
1143 auto min = body_b.AddInstruction(
1144 HloInstruction::CreateConstant(LiteralUtil::MinValue(
1145 reshaped_slice_operand->shape().element_type())));
1146 std::vector<int64_t> min_padding(
1147 reshaped_slice_operand->shape().rank());
1148 auto padded_slice_operand = reshaped_slice_operand;
1149 auto padded_shape = padded_slice_operand->shape();
1150 int64_t padding_dim = slice_sharding_dim;
1151 padded_shape.set_dimensions(padding_dim, 2);
1152 if (ccw) {
1153 // ccw pad high
1154 PaddingConfig ccw_pad_config =
1155 window_util::MakeSymmetricPadding(min_padding);
1156 ccw_pad_config.mutable_dimensions(padding_dim)
1157 ->set_edge_padding_low(0);
1158 ccw_pad_config.mutable_dimensions(padding_dim)
1159 ->set_edge_padding_high(1);
1160 padded_slice_operand =
1161 body_b.AddInstruction(HloInstruction::CreatePad(
1162 padded_shape, padded_slice_operand, min, ccw_pad_config));
1163 } else {
1164 // cw pad low
1165 PaddingConfig cw_pad_config =
1166 window_util::MakeSymmetricPadding(min_padding);
1167 cw_pad_config.mutable_dimensions(padding_dim)
1168 ->set_edge_padding_low(1);
1169 cw_pad_config.mutable_dimensions(padding_dim)
1170 ->set_edge_padding_high(0);
1171 padded_slice_operand =
1172 body_b.AddInstruction(HloInstruction::CreatePad(
1173 padded_shape, padded_slice_operand, min, cw_pad_config));
1174 }
1175
1176 padded_slice_operand->set_sharding(HloSharding::Replicate());
1177 auto state = lhs.state();
1178 state.b = &body_b;
1179 state.partition_id = data_partition_id;
1180 state.reshard_cache->per_hlo_cache.erase(padded_slice_operand);
1181 auto padded_slice_sharding = hlo_sharding_util::ReshapeSharding(
1182 slice_operand->shape(), reshaped_slice_operand->shape(),
1183 *slice_sharding);
1184 auto padded_slice =
1185 PartitionedHlo(padded_slice_operand,
1186 padded_slice_operand->shape(), state)
1187 .Reshard(*padded_slice_sharding)
1188 .hlo();
1189 padded_slice_operand->clear_sharding();
1190 return padded_slice;
1191 };
1192
1193 auto ccw_slice = gen_slice(ccw_data_partition_id, true);
1194 auto cw_slice = gen_slice(cw_data_partition_id, false);
1195 auto slice = body_b.AddInstruction(HloInstruction::CreateBinary(
1196 ccw_slice->shape(), HloOpcode::kMaximum, ccw_slice, cw_slice));
// Reshape. The reshaped slice will not be used to produce the final
// result, but serves as a hint for shape inference.
1199 std::vector<int64_t> reshaped_slice_dims;
1200 const int64_t dim_size = slice->shape().dimensions_size();
1201 reshaped_slice_dims.reserve(dim_size);
1202 for (int64_t i = 0; i < dim_size; ++i) {
1203 auto dim_size = slice->shape().dimensions(i);
1204 if (i == (slice_sharding_dim + 1)) {
1205 reshaped_slice_dims.push_back(dim_size * 2);
1206 } else if (i != slice_sharding_dim) {
1207 reshaped_slice_dims.push_back(dim_size);
1208 }
1209 }
1210 auto reshaped_slice =
1211 body_b.AddInstruction(HloInstruction::CreateReshape(
1212 ShapeUtil::MakeShape(slice->shape().element_type(),
1213 reshaped_slice_dims),
1214 slice));
1215
1216 if (!windowed_op_is_lhs) {
1217 dot_lhs = slice;
1218 original_dot_lhs = reshaped_slice;
1219 if (original_hlo->opcode() == HloOpcode::kDot) {
1220 UpdateDDNums(&new_ddnums, slice_sharding_dim, true);
1221 }
1222 } else {
1223 dot_rhs = slice;
1224 original_dot_rhs = reshaped_slice;
1225 if (original_hlo->opcode() == HloOpcode::kDot) {
1226 UpdateDDNums(&new_ddnums, slice_sharding_dim, false);
1227 }
1228 }
1229 }
1230
1231 auto ccw_dot_lhs = l;
1232 auto ccw_dot_rhs = r;
1233 auto cw_dot_lhs = windowed_op_is_lhs ? extra_inout : l;
1234 auto cw_dot_rhs = windowed_op_is_lhs ? r : extra_inout;
1235 if (lhs_concat_dim != -1 && windowed_op_is_lhs) {
1236 // Concat
1237 auto lhs_concat_shape = ccw_dot_lhs->shape();
1238 lhs_concat_shape.set_dimensions(lhs_concat_dim, 2);
1239 dot_lhs = body_b.AddInstruction(HloInstruction::CreateConcatenate(
1240 lhs_concat_shape, {ccw_dot_lhs, cw_dot_lhs}, lhs_concat_dim));
1241
1242 std::vector<int64_t> reshaped_dims(
1243 ccw_dot_lhs->shape().dimensions().begin(),
1244 ccw_dot_lhs->shape().dimensions().end());
1245 reshaped_dims.erase(reshaped_dims.begin() + lhs_concat_dim);
1246 reshaped_dims[lhs_concat_dim] *= 2;
1247 original_dot_lhs = body_b.AddInstruction(HloInstruction::CreateReshape(
1248 ShapeUtil::MakeShape(dot_lhs->shape().element_type(),
1249 reshaped_dims),
1250 dot_lhs));
1251
1252 if (original_hlo->opcode() == HloOpcode::kDot) {
1253 UpdateDDNums(&new_ddnums, lhs_concat_dim, true);
1254 }
1255 }
1256 if (rhs_concat_dim != -1 && !windowed_op_is_lhs) {
1257 // Concat
1258 auto rhs_concat_shape = ccw_dot_rhs->shape();
1259 rhs_concat_shape.set_dimensions(rhs_concat_dim, 2);
1260 dot_rhs = body_b.AddInstruction(HloInstruction::CreateConcatenate(
1261 rhs_concat_shape, {ccw_dot_rhs, cw_dot_rhs}, rhs_concat_dim));
1262
1263 std::vector<int64_t> reshaped_dims(
1264 ccw_dot_rhs->shape().dimensions().begin(),
1265 ccw_dot_rhs->shape().dimensions().end());
1266 reshaped_dims.erase(reshaped_dims.begin() + rhs_concat_dim);
1267 reshaped_dims[rhs_concat_dim] *= 2;
1268 original_dot_rhs = body_b.AddInstruction(HloInstruction::CreateReshape(
1269 ShapeUtil::MakeShape(dot_rhs->shape().element_type(),
1270 reshaped_dims),
1271 dot_rhs));
1272
1273 if (original_hlo->opcode() == HloOpcode::kDot) {
1274 UpdateDDNums(&new_ddnums, rhs_concat_dim, false);
1275 }
1276 }
1277
1278 // The generated original dot will not be used.
1279 TF_ASSIGN_OR_RETURN(auto original_dot,
1280 create_sharded_dot(original_dot_lhs, original_dot_rhs,
1281 &body_b, conv_window));
1282 VLOG(2) << original_dot->ToString();
1283
1284 // Generate the correct shape of the new dot/conv.
1285 auto original_sharded_dot_shape = original_dot->shape();
1286 auto new_dot_shape = original_sharded_dot_shape;
1287 std::vector<int64_t> new_dims(new_dot_shape.dimensions().begin(),
1288 new_dot_shape.dimensions().end());
1289 if (!windowed_at_contracting_dims) {
1290 auto slice_dim =
1291 lhs_concat_dim != -1
1292 ? indices_map.lhs_to_output_indices[lhs_concat_dim]
1293 : indices_map.rhs_to_output_indices[rhs_concat_dim];
1294 new_dims[slice_dim] /= 2;
1295 new_dims.insert(new_dims.begin() + slice_dim, 2);
1296 } else if (original_hlo->opcode() != HloOpcode::kDot) {
1297 new_dims.push_back(1);
1298 }
1299 new_dot_shape =
1300 ShapeUtil::MakeShape(original_hlo->shape().element_type(), new_dims);
1301
1302 HloInstruction* dot;
1303 if (original_hlo->opcode() == HloOpcode::kDot) {
1304 dot = body_b.AddInstruction(HloInstruction::CreateDot(
1305 new_dot_shape, dot_lhs, dot_rhs, new_ddnums,
1306 original_hlo->precision_config()));
1307 } else {
1308 if (!windowed_at_contracting_dims && !windowed_at_batch_dims) {
1309 if (lhs_concat_dim != -1) {
1310 std::vector<int64_t> new_dims(dot_rhs->shape().dimensions().begin(),
1311 dot_rhs->shape().dimensions().end());
1312 new_dims.push_back(1);
1313 dot_rhs = body_b.AddInstruction(HloInstruction::CreateReshape(
1314 ShapeUtil::MakeShape(dot_rhs->shape().element_type(), new_dims),
1315 dot_rhs));
1316 }
1317 if (rhs_concat_dim != -1) {
1318 std::vector<int64_t> new_dims(dot_lhs->shape().dimensions().begin(),
1319 dot_lhs->shape().dimensions().end());
1320 new_dims.push_back(1);
1321 dot_lhs = body_b.AddInstruction(HloInstruction::CreateReshape(
1322 ShapeUtil::MakeShape(dot_lhs->shape().element_type(), new_dims),
1323 dot_lhs));
1324 }
1325 }
1326
1327 dot = body_b.AddInstruction(HloInstruction::CreateConvolve(
1328 new_dot_shape, dot_lhs, dot_rhs,
1329 original_dot->feature_group_count(),
1330 original_dot->batch_group_count(),
1331 GenNewWindow(original_dot, dot_lhs, dot_rhs, lhs_concat_dim,
1332 rhs_concat_dim, windowed_at_contracting_dims,
1333 windowed_at_batch_dims),
1334 GenNewConvDNums(original_dot, dot_lhs, dot_rhs, lhs_concat_dim,
1335 rhs_concat_dim, windowed_at_contracting_dims,
1336 windowed_at_batch_dims,
1337 indices_map.lhs_to_output_indices,
1338 indices_map.rhs_to_output_indices, new_dot_shape),
1339 original_dot->precision_config()));
1340 }
1341 VLOG(2) << dot->ToString();
1342
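// When windowed at contracting dims, the concatenated dimension is reduced
// inside the dot itself, so the single result already sums the ccw and cw
// contributions; otherwise the two halves occupy separate output positions
// and are split apart below.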
1343 if (windowed_at_contracting_dims) {
1344 if (original_hlo->opcode() != HloOpcode::kDot) {
1345 // Reshape to the original sharded dot shape.
1346 dot = body_b.AddInstruction(
1347 HloInstruction::CreateReshape(original_sharded_dot_shape, dot));
1348 }
1349
1350 // Accumulate the partial output to the result buffer.
1351 o = body_b.AddInstruction(
1352 HloInstruction::CreateBinary(o->shape(), HloOpcode::kAdd, o, dot));
1353 } else {
1354 // The windowing operand is partitioned along batch/non-contracting
1355 // dimensions, so we need a dynamic-update-slice to save the partial
1356 // output in the result buffer.
1357 auto slice_shape = dot->shape();
1358 auto slice_dim =
1359 lhs_concat_dim != -1
1360 ? indices_map.lhs_to_output_indices[lhs_concat_dim]
1361 : indices_map.rhs_to_output_indices[rhs_concat_dim];
1362 slice_shape.set_dimensions(slice_dim, 1);
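// Slice out the ccw half (index 0) and the cw half (index 1) along the
// size-2 dimension, then reshape each back to the per-iteration sharded
// dot shape.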
1363 std::vector<int64_t> ccw_start_indices(dot->shape().rank(), 0);
1364 std::vector<int64_t> cw_start_indices(dot->shape().rank(), 0);
1365 cw_start_indices[slice_dim] = 1;
1366 auto ccw_dot = body_b.AddInstruction(HloInstruction::CreateSlice(
1367 slice_shape, dot, ccw_start_indices, slice_shape.dimensions(),
1368 std::vector<int64_t>(dot->shape().rank(), 1)));
1369 auto cw_dot = body_b.AddInstruction(HloInstruction::CreateSlice(
1370 slice_shape, dot, cw_start_indices, dot->shape().dimensions(),
1371 std::vector<int64_t>(dot->shape().rank(), 1)));
1372
1373 std::vector<int64_t> reshaped_dims(
1374 original_sharded_dot_shape.dimensions().begin(),
1375 original_sharded_dot_shape.dimensions().end());
1376 reshaped_dims[slice_dim] /= 2;
1377 ccw_dot = body_b.AddInstruction(HloInstruction::CreateReshape(
1378 ShapeUtil::MakeShape(ccw_dot->shape().element_type(),
1379 reshaped_dims),
1380 ccw_dot));
1381 cw_dot = body_b.AddInstruction(HloInstruction::CreateReshape(
1382 ShapeUtil::MakeShape(cw_dot->shape().element_type(), reshaped_dims),
1383 cw_dot));
1384
1385 if (operands_sharded_at_contracting_dims) {
1386 // Accumulate the partial output to the result buffer.
1387 o = body_b.AddInstruction(HloInstruction::CreateBinary(
1388 o->shape(), HloOpcode::kAdd, o, ccw_dot));
1389 cw_cp_output = body_b.AddInstruction(HloInstruction::CreateBinary(
1390 o->shape(), HloOpcode::kAdd, cw_cp_output, cw_dot));
1391 } else {
1392 auto ccw_offsets = MakePartitionOffsets(
1393 o->shape(),
1394 windowed_op_is_lhs ? *lhs_sharding_transposed_to_match_output
1395 : *rhs_sharding_transposed_to_match_output,
1396 ccw_data_partition_id, &body_b);
1397 auto cw_offsets = MakePartitionOffsets(
1398 o->shape(),
1399 windowed_op_is_lhs ? *lhs_sharding_transposed_to_match_output
1400 : *rhs_sharding_transposed_to_match_output,
1401 cw_data_partition_id, &body_b);
1402 o = body_b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
1403 o->shape(), o, ccw_dot, ccw_offsets));
1404 o = body_b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
1405 o->shape(), o, cw_dot, cw_offsets));
1406 }
1407 }
1408
1409 std::vector<HloInstruction*> partial_results;
1410 partial_results.push_back(o);
1411 partial_results.push_back(cw_cp_output);
1412 return partial_results;
1413 };
1414
1415 // Generate the partial result used by the unidirectional algorithm.
1416 auto get_partial_unid_result =
1417 [&](HloInstruction* l, HloInstruction* r, HloInstruction* o,
1418 HloInstruction* i) -> StatusOr<HloInstruction*> {
1419 auto partition_id =
1420 lhs.state().collective_ops_creator.create_partition_id(&body_b);
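// data_partition_id = (iteration + partition_id) % num_partitions is the
// index of the logical data shard this partition processes at the current
// iteration.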
1421 auto data_partition_id =
1422 body_b.AddInstruction(HloInstruction::CreateBinary(
1423 i->shape(), HloOpcode::kAdd, i, partition_id));
1424 auto partition_count =
1425 body_b.AddInstruction(HloInstruction::CreateConstant(
1426 LiteralUtil::CreateR0<uint32_t>(num_partitions)));
1427 data_partition_id = body_b.AddInstruction(
1428 HloInstruction::CreateBinary(i->shape(), HloOpcode::kRemainder,
1429 data_partition_id, partition_count));
1430 auto dot_lhs = l;
1431 auto dot_rhs = r;
1432 if (windowed_at_contracting_dims || windowed_at_batch_dims ||
1433 operands_sharded_at_contracting_dims) {
1434 // Slice the matching operand according to the partitioned dimensions on
1435 // the windowed operand or the output.
1436 auto slice_operand = !windowed_op_is_lhs ? l : r;
1437 // We do this by treating the matching operand as replicated, and
1438 // resharding it to match the windowed operand or the output.
1439 slice_operand->set_sharding(HloSharding::Replicate());
1440 auto state = lhs.state();
1441 state.b = &body_b;
1442 state.partition_id = data_partition_id;
1443 state.reshard_cache->per_hlo_cache.erase(slice_operand);
1444 auto slice =
1445 PartitionedHlo(slice_operand, slice_operand->shape(), state)
1446 .Reshard(*slice_sharding)
1447 .hlo();
1448 slice_operand->clear_sharding();
1449 if (!windowed_op_is_lhs) {
1450 dot_lhs = slice;
1451 } else {
1452 dot_rhs = slice;
1453 }
1454 }
1455 TF_ASSIGN_OR_RETURN(
1456 auto dot, create_sharded_dot(dot_lhs, dot_rhs, &body_b, conv_window));
1457 if (windowed_at_contracting_dims ||
1458 operands_sharded_at_contracting_dims) {
1459 // Accumulate the partial output to the result buffer.
1460 o = body_b.AddInstruction(
1461 HloInstruction::CreateBinary(o->shape(), HloOpcode::kAdd, o, dot));
1462 } else {
1463 // The windowing operand is partitioned along batch/non-contracting
1464 // dimensions, so we need a dynamic-update-slice to save the partial
1465 // output in the result buffer.
1466 auto offsets = MakePartitionOffsets(
1467 o->shape(),
1468 windowed_op_is_lhs ? *lhs_sharding_transposed_to_match_output
1469 : *rhs_sharding_transposed_to_match_output,
1470 data_partition_id, &body_b);
1471 o = body_b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
1472 o->shape(), o, dot, offsets));
1473 }
1474 return o;
1475 };
1476
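// The while-loop state is the tuple (lhs, rhs, result accumulator, extra
// buffer, iteration counter).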
1477 auto param = body_b.AddInstruction(HloInstruction::CreateParameter(
1478 /*parameter_number=*/0,
1479 ShapeUtil::MakeTupleShapeWithPtrs(
1480 {&lhs_hlo->shape(), &rhs_hlo->shape(), &result_buffer->shape(),
1481 &extra_buffer->shape(), &iteration->shape()}),
1482 "param"));
1483 auto l = body_b.AddInstruction(
1484 HloInstruction::CreateGetTupleElement(lhs_hlo->shape(), param, 0));
1485 auto r = body_b.AddInstruction(
1486 HloInstruction::CreateGetTupleElement(rhs_hlo->shape(), param, 1));
1487 auto o = body_b.AddInstruction(HloInstruction::CreateGetTupleElement(
1488 result_buffer->shape(), param, 2));
1489 auto extra_inout = body_b.AddInstruction(
1490 HloInstruction::CreateGetTupleElement(extra_buffer->shape(), param, 3));
1491 auto i = body_b.AddInstruction(
1492 HloInstruction::CreateGetTupleElement(iteration->shape(), param, 4));
1493
1494 // The bidirectional collective permute implementation has loop unrolling
1495 // of degree 2, so num_partitions is required to be a multiple of 4.
1496 if (options.bidirectional_windowed_einsum && num_partitions % 4 == 0) {
1497 std::vector<std::pair<int64_t, int64_t>> ccw_sd_pairs(num_partitions);
1498 for (int64_t source = 0; source < num_partitions; ++source) {
1499 // 0 -> n-1, 1 -> 0, 2 -> 1, ...
1500 ccw_sd_pairs[source] = {source,
1501 (source - 1 + num_partitions) % num_partitions};
1502 }
1503 std::vector<std::pair<int64_t, int64_t>> cw_sd_pairs(num_partitions);
1504 for (int64_t source = 0; source < num_partitions; ++source) {
1505 // 0 -> 1, 1 -> 2, 2 -> 3, ...
1506 cw_sd_pairs[source] = {source, (source + 1) % num_partitions};
1507 }
1508
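// Each loop body below performs two unrolled steps, and each step computes
// partial results for both the ccw and the cw shard, so four shards are
// consumed per body execution (hence num_partitions % 4 == 0).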
1509 // Even-numbered iteration.
1510 auto next_l = l;
1511 auto next_r = r;
1512 auto ccw_cp_input = operands_sharded_at_contracting_dims ? o
1513 : windowed_op_is_lhs ? l
1514 : r;
1515 auto ccw_cp_output =
1516 lhs.state()
1517 .collective_ops_creator.create_cross_partition_collective_permute(
1518 &body_b, ccw_cp_input, ccw_sd_pairs,
1519 (*lhs.state().next_channel_id)++);
1520 if (operands_sharded_at_contracting_dims) {
1521 o = ccw_cp_output;
1522 } else if (windowed_op_is_lhs) {
1523 next_l = ccw_cp_output;
1524 } else {
1525 next_r = ccw_cp_output;
1526 }
1527 auto cw_cp_input = extra_inout;
1528 auto cw_cp_output =
1529 lhs.state()
1530 .collective_ops_creator.create_cross_partition_collective_permute(
1531 &body_b, cw_cp_input, cw_sd_pairs,
1532 (*lhs.state().next_channel_id)++);
1533
1534 TF_ASSIGN_OR_RETURN(
1535 auto outputs,
1536 get_partial_bid_results(l, r, o, extra_inout, cw_cp_output, i));
1537 o = outputs[0];
1538 cw_cp_output = outputs[1];
1539
1540 // ++i
1541 i = body_b.AddInstruction(HloInstruction::CreateBinary(
1542 i->shape(), HloOpcode::kAdd, i, CreateOne(i->shape(), &body_b)));
1543
1544 // Odd-numbered iteration.
1545 auto second_next_l = next_l;
1546 auto second_next_r = next_r;
1547 ccw_cp_input = operands_sharded_at_contracting_dims ? o
1548 : windowed_op_is_lhs ? next_l
1549 : next_r;
1550 ccw_cp_output =
1551 lhs.state()
1552 .collective_ops_creator.create_cross_partition_collective_permute(
1553 &body_b, ccw_cp_input, ccw_sd_pairs,
1554 (*lhs.state().next_channel_id)++);
1555 if (operands_sharded_at_contracting_dims) {
1556 o = ccw_cp_output;
1557 } else if (windowed_op_is_lhs) {
1558 second_next_l = ccw_cp_output;
1559 } else {
1560 second_next_r = ccw_cp_output;
1561 }
1562 auto next_cw_cp_input = cw_cp_output;
1563 auto next_cw_cp_output =
1564 lhs.state()
1565 .collective_ops_creator.create_cross_partition_collective_permute(
1566 &body_b, next_cw_cp_input, cw_sd_pairs,
1567 (*lhs.state().next_channel_id)++);
1568
1569 TF_ASSIGN_OR_RETURN(
1570 outputs, get_partial_bid_results(next_l, next_r, o, cw_cp_output,
1571 next_cw_cp_output, i));
1572 o = outputs[0];
1573 next_cw_cp_output = outputs[1];
1574
1575 // ++i
1576 i = body_b.AddInstruction(HloInstruction::CreateBinary(
1577 i->shape(), HloOpcode::kAdd, i, CreateOne(i->shape(), &body_b)));
1578
1579 body_b.AddInstruction(HloInstruction::CreateTuple(
1580 {second_next_l, second_next_r, o, next_cw_cp_output, i}));
1581
1582 } else if (options.unroll_windowed_einsum && num_partitions % 2 == 0) {
1583 if (operands_sharded_at_contracting_dims) {
1584 std::vector<std::pair<int64_t, int64_t>> output_sd_pairs(
1585 num_partitions);
1586 for (int64_t source = 0; source < num_partitions; ++source) {
1587 // 0 -> n-2, 1 -> n-1, 2 -> 0, ...
1588 output_sd_pairs[source] = {
1589 source, (source - 2 + num_partitions) % num_partitions};
1590 }
1591
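// In the unrolled reduce-scatter case the accumulators themselves rotate:
// `o` is shifted by two partitions per body iteration while `extra_inout`
// accumulates the intermediate step; the two buffers are combined after the
// loop.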
1592 o = lhs.state()
1593 .collective_ops_creator
1594 .create_cross_partition_collective_permute(
1595 &body_b, o, output_sd_pairs,
1596 (*lhs.state().next_channel_id)++);
1597
1598 TF_ASSIGN_OR_RETURN(extra_inout,
1599 get_partial_unid_result(l, r, extra_inout, i));
1600
1601 extra_inout = lhs.state()
1602 .collective_ops_creator
1603 .create_cross_partition_collective_permute(
1604 &body_b, extra_inout, output_sd_pairs,
1605 (*lhs.state().next_channel_id)++);
1606
1607 // i += 2
1608 i = body_b.AddInstruction(HloInstruction::CreateBinary(
1609 i->shape(), HloOpcode::kAdd, i,
1610 body_b.AddInstruction(HloInstruction::CreateConstant(
1611 LiteralUtil::CreateR0<uint32_t>(2)))));
1612 auto real_i = body_b.AddInstruction(HloInstruction::CreateBinary(
1613 i->shape(), HloOpcode::kAdd, i,
1614 body_b.AddInstruction(HloInstruction::CreateConstant(
1615 LiteralUtil::CreateR0<uint32_t>(1)))));
1616
1617 TF_ASSIGN_OR_RETURN(o, get_partial_unid_result(l, r, o, real_i));
1618 body_b.AddInstruction(
1619 HloInstruction::CreateTuple({l, r, o, extra_inout, i}));
1620 } else {
1621 std::vector<std::pair<int64_t, int64_t>> sd_pairs(num_partitions);
1622 for (int64_t source = 0; source < num_partitions; ++source) {
1623 // 0 -> n-1, 1 -> 0, 2 -> 1, ...
1624 sd_pairs[source] = {source,
1625 (source - 1 + num_partitions) % num_partitions};
1626 }
1627
1628 // Even-numbered iteration.
1629 auto next_l = l;
1630 auto next_r = r;
1631 auto cp_input = windowed_op_is_lhs ? l : r;
1632 auto cp_output = lhs.state()
1633 .collective_ops_creator
1634 .create_cross_partition_collective_permute(
1635 &body_b, cp_input, sd_pairs,
1636 (*lhs.state().next_channel_id)++);
1637 if (windowed_op_is_lhs) {
1638 next_l = cp_output;
1639 } else {
1640 next_r = cp_output;
1641 }
1642 TF_ASSIGN_OR_RETURN(o, get_partial_unid_result(l, r, o, i));
1643
1644 // ++i
1645 i = body_b.AddInstruction(HloInstruction::CreateBinary(
1646 i->shape(), HloOpcode::kAdd, i,
1647 body_b.AddInstruction(HloInstruction::CreateConstant(
1648 LiteralUtil::CreateR0<uint32_t>(1)))));
1649
1650 // Odd-numbered iteration.
1651 auto second_next_l = next_l;
1652 auto second_next_r = next_r;
1653 cp_input = windowed_op_is_lhs ? next_l : next_r;
1654 cp_output = lhs.state()
1655 .collective_ops_creator
1656 .create_cross_partition_collective_permute(
1657 &body_b, cp_input, sd_pairs,
1658 (*lhs.state().next_channel_id)++);
1659 if (windowed_op_is_lhs) {
1660 second_next_l = cp_output;
1661 } else {
1662 second_next_r = cp_output;
1663 }
1664 TF_ASSIGN_OR_RETURN(o, get_partial_unid_result(next_l, next_r, o, i));
1665
1666 // ++i
1667 i = body_b.AddInstruction(HloInstruction::CreateBinary(
1668 i->shape(), HloOpcode::kAdd, i,
1669 body_b.AddInstruction(HloInstruction::CreateConstant(
1670 LiteralUtil::CreateR0<uint32_t>(1)))));
1671
1672 body_b.AddInstruction(HloInstruction::CreateTuple(
1673 {second_next_l, second_next_r, o, extra_inout, i}));
1674 }
1675 } else {
1676 auto real_i = i;
1677 if (operands_sharded_at_contracting_dims) {
1678 // For the reduce-scatter case, start from data_partition_id + 1 so that
1679 // the data_partition_id of the final data shard in each partition matches
1680 // the corresponding partition_id.
1681 real_i = body_b.AddInstruction(HloInstruction::CreateBinary(
1682 real_i->shape(), HloOpcode::kAdd, real_i,
1683 CreateOne(real_i->shape(), &body_b)));
1684 }
1685 TF_ASSIGN_OR_RETURN(o, get_partial_unid_result(l, r, o, real_i));
1686
1687 // ++i
1688 i = body_b.AddInstruction(HloInstruction::CreateBinary(
1689 i->shape(), HloOpcode::kAdd, i,
1690 body_b.AddInstruction(HloInstruction::CreateConstant(
1691 LiteralUtil::CreateR0<uint32_t>(1)))));
1692 auto has_more = body_b.AddInstruction(HloInstruction::CreateCompare(
1693 ShapeUtil::MakeShape(PRED, {}), i,
1694 body_b.AddInstruction(HloInstruction::CreateConstant(
1695 LiteralUtil::CreateR0<uint32_t>(num_partitions))),
1696 ComparisonDirection::kLt));
1697 // Collective-permute for the next window. We don't need it for the last
1698 // iteration, so we use a conditional around the collective-permute.
1699 HloInstruction* conditional;
1700 {
1701 SpmdBuilder cp_b("window_collective_permute", original_hlo);
1702 {
1703 auto p = cp_b.AddInstruction(HloInstruction::CreateParameter(
1704 0,
1705 operands_sharded_at_contracting_dims ? o->shape()
1706 : windowed_op_is_lhs ? l->shape()
1707 : r->shape(),
1708 "window"));
1709 std::vector<std::pair<int64_t, int64_t>> sd_pairs(num_partitions);
1710 for (int64_t source = 0; source < num_partitions; ++source) {
1711 // 0 -> n-1, 1 -> 0, 2 -> 1, ...
1712 sd_pairs[source] = {source,
1713 (source - 1 + num_partitions) % num_partitions};
1714 }
1715 lhs.state()
1716 .collective_ops_creator.create_cross_partition_collective_permute(
1717 &cp_b, p, sd_pairs, (*lhs.state().next_channel_id)++);
1718 }
1719 SpmdBuilder ncp_b("last_iteration_noop", original_hlo);
1720 {
1721 ncp_b.AddInstruction(HloInstruction::CreateParameter(
1722 0,
1723 operands_sharded_at_contracting_dims ? o->shape()
1724 : windowed_op_is_lhs ? l->shape()
1725 : r->shape(),
1726 "window"));
1727 }
1728 conditional = body_b.AddInstruction(HloInstruction::CreateConditional(
1729 operands_sharded_at_contracting_dims ? o->shape()
1730 : windowed_op_is_lhs ? l->shape()
1731 : r->shape(),
1732 has_more,
1733 operands_sharded_at_contracting_dims ? o
1734 : windowed_op_is_lhs ? l
1735 : r,
1736 module->AddEmbeddedComputation(cp_b.Build()),
1737 operands_sharded_at_contracting_dims ? o
1738 : windowed_op_is_lhs ? l
1739 : r,
1740 module->AddEmbeddedComputation(ncp_b.Build())));
1741 }
1742 if (operands_sharded_at_contracting_dims) {
1743 o = conditional;
1744 } else if (windowed_op_is_lhs) {
1745 l = conditional;
1746 } else {
1747 r = conditional;
1748 }
1749 body_b.AddInstruction(
1750 HloInstruction::CreateTuple({l, r, o, extra_inout, i}));
1751 }
1752
1753 SpmdBuilder cond_b("windowed_dot_general_cond", original_hlo);
1754 auto cond_param = cond_b.AddInstruction(HloInstruction::CreateParameter(
1755 /*parameter_number=*/0,
1756 ShapeUtil::MakeTupleShapeWithPtrs(
1757 {&lhs_hlo->shape(), &rhs_hlo->shape(), &result_buffer->shape(),
1758 &extra_buffer->shape(), &iteration->shape()}),
1759 "param"));
1760 auto cond_i = cond_b.AddInstruction(HloInstruction::CreateGetTupleElement(
1761 iteration->shape(), cond_param, 4));
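// The bidirectional variant covers both a ccw and a cw shard per step, so
// its loop condition compares against half the partition count.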
1762 int64_t adapted_num_partitions =
1763 (options.bidirectional_windowed_einsum && num_partitions % 4 == 0)
1764 ? num_partitions / 2
1765 : num_partitions;
1766 cond_b.AddInstruction(HloInstruction::CreateCompare(
1767 ShapeUtil::MakeShape(PRED, {}), cond_i,
1768 cond_b.AddInstruction(HloInstruction::CreateConstant(
1769 LiteralUtil::CreateR0<uint32_t>(adapted_num_partitions))),
1770 ComparisonDirection::kLt));
1771 auto while_loop = b->AddInstruction(HloInstruction::CreateWhile(
1772 cond_param->shape(), module->AddEmbeddedComputation(cond_b.Build()),
1773 module->AddEmbeddedComputation(body_b.Build()),
1774 b->AddInstruction(HloInstruction::CreateTuple(
1775 {lhs_hlo, rhs_hlo, result_buffer, extra_buffer, iteration}))));
1776 windowed_dot_general_loops->push_back(
1777 {while_loop, windowed_op_is_lhs ? 0 : 1, windowed_at_contracting_dims,
1778 windowed_at_batch_dims, operands_sharded_at_contracting_dims,
1779 num_partitions, GetLoopReplicaGroups(while_loop)});
1780 auto result = b->AddInstruction(HloInstruction::CreateGetTupleElement(
1781 result_buffer->shape(), while_loop, 2));
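// For the unrolled and bidirectional reduce-scatter variants, the final
// partial results are split between the main result buffer and the extra
// buffer, offset by one partition; permute one of them into alignment and
// add the two together.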
1782 if (((options.bidirectional_windowed_einsum && num_partitions % 4 == 0) ||
1783 (options.unroll_windowed_einsum && num_partitions % 2 == 0)) &&
1784 operands_sharded_at_contracting_dims) {
1785 std::vector<std::pair<int64_t, int64_t>> extra_sd_pairs(num_partitions);
1786 for (int64_t source = 0; source < num_partitions; ++source) {
1787 // 0 -> 1, 1 -> 2, 2 -> 3, ...
1788 extra_sd_pairs[source] = {source, (source + 1) % num_partitions};
1789 }
1790 auto extra_result =
1791 b->AddInstruction(HloInstruction::CreateGetTupleElement(
1792 extra_buffer->shape(), while_loop, 3));
1793 if (options.bidirectional_windowed_einsum && num_partitions % 4 == 0) {
1794 extra_result = lhs.state()
1795 .collective_ops_creator
1796 .create_cross_partition_collective_permute(
1797 b, extra_result, extra_sd_pairs,
1798 (*lhs.state().next_channel_id)++);
1799 }
1800 if (options.unroll_windowed_einsum && num_partitions % 2 == 0) {
1801 result = lhs.state()
1802 .collective_ops_creator
1803 .create_cross_partition_collective_permute(
1804 b, result, extra_sd_pairs,
1805 (*lhs.state().next_channel_id)++);
1806 }
1807 result = b->AddInstruction(HloInstruction::CreateBinary(
1808 result->shape(), HloOpcode::kAdd, result, extra_result));
1809 }
1810 if (!ShapeUtil::Compatible(padded_result_buffer_shape,
1811 unpadded_result_buffer_shape)) {
1812 result = b->AddInstruction(HloInstruction::CreateSlice(
1813 unpadded_result_buffer_shape, result,
1814 std::vector<int64_t>(padded_result_buffer_shape.rank(), 0),
1815 unpadded_result_buffer_shape.dimensions(),
1816 std::vector<int64_t>(padded_result_buffer_shape.rank(), 1)));
1817 }
1818 return result;
1819 };
1820 // Hard limit on the iteration count based on empirical data (above this
1821 // count the loop overhead becomes significant).
1822 constexpr int64_t kMaxIterations = 32;
1823 std::optional<WindowedEinsumConfig> e_config = GetWindowedEinsumConfiguration(
1824 num_partitions, output_lhs_non_contracting_partitions,
1825 output_rhs_non_contracting_partitions, rhs_contracting_partitions,
1826 rhs_non_contracting_partitions, rhs_batch_partitions,
1827 lhs_contracting_partitions, lhs_non_contracting_partitions,
1828 lhs_batch_partitions, ShapeSizeInBytes(rhs.base_shape()),
1829 ShapeSizeInBytes(lhs.base_shape()), ShapeSizeInBytes(output_base_shape),
1830 options, output_sharding_transposed_to_match_lhs,
1831 output_sharding_transposed_to_match_rhs,
1832 lhs_sharding_transposed_to_match_rhs,
1833 rhs_sharding_transposed_to_match_lhs, lhs_sharding, rhs_sharding,
1834 conv_window, dims_mapping, kMaxIterations, original_hlo, &lhs, &rhs,
1835 create_sharded_dot, b, module, visitor);
1836 if (e_config) {
1837 VLOG(2) << "Emit windowed dot.";
1838 return emit_windowed_dot_general(*e_config);
1839 }
1840
1841 {
1842 // Try batch-parallel by resharding one operand, and allowing all-reduce.
1843 TF_ASSIGN_OR_RETURN(
1844 HloInstruction * partitioned_dot,
1845 try_emit_output_batch_partitioned_einsum_with_reshard(true));
1846 if (partitioned_dot) {
1847 return partitioned_dot;
1848 }
1849 }
1850
1851 // LHS and RHS have the same partitioned contracting dimensions.
1852 if (lhs_contracting_partitions == rhs_contracting_partitions &&
1853 lhs_contracting_partitions == num_partitions) {
1854 if (!may_reshard_without_detecting_match &&
1855 !output_sharding.IsReplicated() &&
1856 output_sharding.NumTiles() != num_partitions) {
1857 // The output is not fully sliced; the recursive handling has better
1858 // pattern matching for reduce scatters in subgroups.
1859 return nullptr;
1860 }
1861 // Pad both sides with zeros, since a NaN on one side cannot be masked by a
1862 // zero on the other side.
1863 if (ShapeSizeInBytes(lhs.base_shape()) <
1864 ShapeSizeInBytes(rhs.base_shape())) {
1865 lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithZero();
1866 rhs = rhs.PadWithZero();
1867 } else {
1868 lhs = lhs.PadWithZero();
1869 rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithZero();
1870 }
1871 TF_ASSIGN_OR_RETURN(
1872 auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
1873 std::vector<int64_t> lhs_contracting_dims;
1874 lhs_contracting_dims.reserve(lhs.base_shape().rank());
1875 for (const auto& cd : dims_mapping.contracting_dims) {
1876 lhs_contracting_dims.push_back(cd.lhs);
1877 }
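// All-reduce the per-shard partial dots across the contracting-dimension
// sharding groups, then reshard the replicated result to the requested
// output sharding.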
1878 auto ar = lhs.state().partitioner->AllReduceAlongShardingDims(
1879 b, dot, lhs.sharding(), lhs.state().next_channel_id,
1880 lhs_contracting_dims, lhs.state().collective_ops_creator,
1881 MakeBinaryAdd(output_base_shape.element_type(), module));
1882 ar->set_sharding(HloSharding::Replicate());
1883 return PartitionedHlo(ar, output_base_shape, lhs.state())
1884 .Reshard(output_sharding)
1885 .hlo();
1886 }
1887
1888 // LHS and output have the same partitioned non-contracting dimensions.
1889 if (lhs_non_contracting_partitions == num_partitions &&
1890 output_lhs_non_contracting_partitions == num_partitions &&
1891 lhs_sharding_transposed_to_match_output == output_sharding) {
1892 auto rhs_replicated = rhs.Reshard(HloSharding::Replicate()).hlo();
1893 TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs_replicated,
1894 b, conv_window));
1895 return dot;
1896 }
1897
1898 // RHS and output have the same partitioned non-contracting dimensions.
1899 if (rhs_non_contracting_partitions == num_partitions &&
1900 output_rhs_non_contracting_partitions == num_partitions &&
1901 rhs_sharding_transposed_to_match_output == output_sharding) {
1902 auto lhs_replicated = lhs.Reshard(HloSharding::Replicate()).hlo();
1903 TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs_replicated, rhs.hlo(),
1904 b, conv_window));
1905 return dot;
1906 }
1907
1908 if (may_reshard_without_detecting_match) {
1909 // Output is batch partitioned.
1910 if (output_batch_partitions == num_partitions) {
1911 auto resharded_lhs =
1912 lhs.Reshard(*output_sharding_transposed_to_match_lhs);
1913 auto resharded_rhs =
1914 rhs.Reshard(*output_sharding_transposed_to_match_rhs);
1915 TF_ASSIGN_OR_RETURN(
1916 auto dot, create_sharded_dot(resharded_lhs.hlo(), resharded_rhs.hlo(),
1917 b, conv_window));
1918 return dot;
1919 }
1920 // Output is partitioned along LHS non-contracting dimensions.
1921 if (output_lhs_non_contracting_partitions == num_partitions) {
1922 auto resharded_lhs =
1923 lhs.Reshard(*output_sharding_transposed_to_match_lhs);
1924 auto replicated_rhs = rhs.Reshard(HloSharding::Replicate());
1925 TF_ASSIGN_OR_RETURN(
1926 auto dot, create_sharded_dot(resharded_lhs.hlo(),
1927 replicated_rhs.hlo(), b, conv_window));
1928 return dot;
1929 }
1930 // Output is partitioned along RHS non-contracting dimensions.
1931 if (output_rhs_non_contracting_partitions == num_partitions) {
1932 auto replicated_lhs = lhs.Reshard(HloSharding::Replicate());
1933 auto resharded_rhs =
1934 rhs.Reshard(*output_sharding_transposed_to_match_rhs);
1935 TF_ASSIGN_OR_RETURN(
1936 auto dot, create_sharded_dot(replicated_lhs.hlo(),
1937 resharded_rhs.hlo(), b, conv_window));
1938 return dot;
1939 }
1940 }
1941
1942 // Returns true if it is beneficial to reshard the operand at `operand_idx`
1943 // across the contracting dimension.
1944 const auto should_partition_contracting_dim = [&](int64_t operand_idx) {
1945 if (!output_sharding.IsReplicated()) {
1946 return false;
1947 }
1948
1949 if (operand_idx == 0) {
1950 // If LHS and output are replicated, we compare the cost of all-gather
1951 // on RHS vs all-reduce on the output.
1952 return (rhs_contracting_partitions == num_partitions) &&
1953 lhs.sharding().IsReplicated() &&
1954 ShapeUtil::ElementsIn(rhs.base_shape()) >
1955 ShapeUtil::ElementsIn(output_base_shape);
1956 } else {
1957 return (lhs_contracting_partitions == num_partitions) &&
1958 rhs.sharding().IsReplicated() &&
1959 ShapeUtil::ElementsIn(lhs.base_shape()) >
1960 ShapeUtil::ElementsIn(output_base_shape);
1961 }
1962 };
1963
1964 // When the output is replicated and one of the operands is partitioned
1965 // along a contracting dimension, align the other operand to be partitioned
1966 // along the contracting dimensions as well.
1967 if (output_sharding.IsReplicated() && (should_partition_contracting_dim(0) ||
1968 should_partition_contracting_dim(1))) {
1969 if (should_partition_contracting_dim(0)) {
1970 lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithZero();
1971 rhs = rhs.PadWithZero();
1972 } else {
1973 lhs = lhs.PadWithZero();
1974 rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithZero();
1975 }
1976 TF_ASSIGN_OR_RETURN(
1977 auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
1978
1979 std::vector<int64_t> lhs_contracting_dims;
1980 lhs_contracting_dims.reserve(lhs.base_shape().rank());
1981 for (const auto& cd : dims_mapping.contracting_dims) {
1982 lhs_contracting_dims.push_back(cd.lhs);
1983 }
1984 return lhs.state().partitioner->AllReduceAlongShardingDims(
1985 b, dot, lhs.sharding(), lhs.state().next_channel_id,
1986 lhs_contracting_dims, lhs.state().collective_ops_creator,
1987 MakeBinaryAdd(output_base_shape.element_type(), module));
1988 }
1989 return nullptr;
1990 }
1991
1992 StatusOr<HloInstruction*> PartitionDot(
1993 PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
1994 const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
1995 int64_t num_partitions,
1996 const std::function<StatusOr<HloInstruction*>(
1997 HloInstruction*, HloInstruction*, SpmdBuilder*,
1998 const Window& conv_window)>& create_sharded_dot,
1999 const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
2000 const SpmdPartitionerOptions& options, SpmdBuilder* b,
2001 std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
2002 windowed_dot_general_loops,
2003 SpmdPartitioningVisitor* visitor);
2004
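// Partitions a dot by grouping devices along batch dimensions, aligning the
// operands' and the output's batch-dimension shardings, and recursively
// partitioning the dot within each group.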
2005 StatusOr<HloInstruction*> PartitionDotGroupOnBatch(
2006 PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
2007 const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
2008 int64_t num_partitions, int64_t lhs_contracting_partitions,
2009 int64_t rhs_contracting_partitions, int64_t lhs_non_contracting_partitions,
2010 int64_t rhs_non_contracting_partitions,
2011 const std::function<StatusOr<HloInstruction*>(
2012 HloInstruction*, HloInstruction*, SpmdBuilder*,
2013 const Window& conv_window)>& create_sharded_dot,
2014 const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
2015 bool require_matching_devices_to_group,
2016 const SpmdPartitionerOptions& options, SpmdBuilder* b,
2017 std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
2018 windowed_dot_general_loops,
2019 SpmdPartitioningVisitor* visitor) {
2020 std::vector<std::pair<HloInstruction*, HloSharding>>
2021 top_level_sharding_to_reset;
2022 absl::Cleanup cleaner = [&] {
2023 for (auto& to_reset : top_level_sharding_to_reset) {
2024 to_reset.first->set_sharding(to_reset.second);
2025 }
2026 };
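// HLOs below temporarily get per-group shardings assigned; their original
// shardings are recorded in top_level_sharding_to_reset and restored by the
// cleanup above when this function returns.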
2027 std::vector<int64_t> lhs_dims;
2028 std::vector<int64_t> rhs_dims;
2029 std::vector<int64_t> output_dims;
2030 auto lhs_sharding_dims_adjusted_to_output =
2031 lhs.sharding().IsReplicated()
2032 ? std::vector<int64_t>(lhs.base_shape().rank(), 1)
2033 : lhs.sharding().tile_assignment().dimensions();
2034 auto rhs_sharding_dims_adjusted_to_output =
2035 rhs.sharding().IsReplicated()
2036 ? std::vector<int64_t>(rhs.base_shape().rank(), 1)
2037 : rhs.sharding().tile_assignment().dimensions();
2038 auto output_sharding_dims_adjusted_to_lhs =
2039 output_sharding.tile_assignment().dimensions();
2040 bool lhs_rhs_dims_matching = true;
2041 for (const auto& dim : dims_mapping.batch_dims) {
2042 lhs_dims.push_back(dim.lhs);
2043 rhs_dims.push_back(dim.rhs);
2044 output_dims.push_back(dim.output);
2045 if (lhs_sharding_dims_adjusted_to_output[dim.lhs] !=
2046 rhs_sharding_dims_adjusted_to_output[dim.rhs]) {
2047 lhs_rhs_dims_matching = false;
2048 }
2049 lhs_sharding_dims_adjusted_to_output[dim.lhs] =
2050 output_sharding.tile_assignment().dim(dim.output);
2051 rhs_sharding_dims_adjusted_to_output[dim.rhs] =
2052 output_sharding.tile_assignment().dim(dim.output);
2053 output_sharding_dims_adjusted_to_lhs[dim.output] =
2054 lhs.sharding().tile_assignment().dim(dim.lhs);
2055 }
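// If the LHS and RHS batch dimensions are sharded the same way, group
// directly on the operands' sharding; otherwise both operands are resharded
// below to follow the output's batch-dimension sharding.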
2056 if (require_matching_devices_to_group && lhs_rhs_dims_matching) {
2057 lhs_rhs_dims_matching =
2058 rhs.sharding() ==
2059 UngroupSharding(AlignGroupsWith(
2060 hlo_sharding_util::GroupShardingOnDims(rhs.sharding(), rhs_dims),
2061 hlo_sharding_util::GroupShardingOnDims(lhs.sharding(), lhs_dims)));
2062 }
2063 auto output_grouped =
2064 hlo_sharding_util::GroupShardingOnDims(output_sharding, output_dims);
2065 PartitionedHlo per_group_lhs = lhs;
2066 PartitionedHlo per_group_rhs = rhs;
2067 if (lhs_rhs_dims_matching) {
2068 auto lhs_grouped =
2069 hlo_sharding_util::GroupShardingOnDims(lhs.sharding(), lhs_dims);
2070 auto rhs_grouped =
2071 hlo_sharding_util::GroupShardingOnDims(rhs.sharding(), rhs_dims);
2072 if (ShapeUtil::ByteSizeOf(lhs.hlo()->shape()) >
2073 ShapeUtil::ByteSizeOf(rhs.hlo()->shape())) {
2074 rhs_grouped = AlignGroupsWith(std::move(rhs_grouped), lhs_grouped);
2075 rhs = rhs.Reshard(UngroupSharding(rhs_grouped));
2076 } else {
2077 lhs_grouped = AlignGroupsWith(std::move(lhs_grouped), rhs_grouped);
2078 lhs = lhs.Reshard(UngroupSharding(lhs_grouped));
2079 }
2080 auto reshaped_output_tiling = output_sharding.tile_assignment();
2081 reshaped_output_tiling.Reshape(output_sharding_dims_adjusted_to_lhs);
2082 output_grouped = AlignGroupsWith(
2083 hlo_sharding_util::GroupShardingOnDims(
2084 output_sharding.ReplicateOnLastTileDim()
2085 ? HloSharding::PartialTile(reshaped_output_tiling)
2086 : HloSharding::Tile(reshaped_output_tiling),
2087 output_dims),
2088 lhs_grouped);
2089 auto per_group_partitioner_state = CreatePerGroupPartitioningState(
2090 lhs.state(), lhs_grouped.device_groups, b);
2091 top_level_sharding_to_reset.emplace_back(lhs.hlo(), lhs.sharding());
2092 lhs.hlo()->set_sharding(lhs_grouped.sharding);
2093 top_level_sharding_to_reset.emplace_back(rhs.hlo(), rhs.sharding());
2094 rhs.hlo()->set_sharding(rhs_grouped.sharding);
2095 CHECK(lhs.hlo() != rhs.hlo() ||
2096 lhs_grouped.sharding == rhs_grouped.sharding);
2097 per_group_lhs = PartitionedHlo(
2098 lhs.hlo(), GetPerGroupBaseShape(lhs_grouped, lhs.base_shape()),
2099 per_group_partitioner_state);
2100 per_group_rhs = PartitionedHlo(
2101 rhs.hlo(), GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()),
2102 per_group_partitioner_state);
2103 } else {
2104 auto per_group_partitioner_state = CreatePerGroupPartitioningState(
2105 lhs.state(), output_grouped.device_groups, b);
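// Reshards `operand` so that its batch dimensions are partitioned the same
// way as the output's batch dimensions and returns the per-group
// PartitionedHlo, or std::nullopt if no compatible resharding is found.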
2106 auto reshard_to_output_batch =
2107 [&](PartitionedHlo operand, absl::Span<const int64_t> batch_dims,
2108 absl::Span<const int64_t> contracting_dims,
2109 absl::Span<const int64_t> non_contracting_dims,
2110 int64_t contracting_dim_partitions,
2111 int64_t non_contracting_dim_partitions,
2112 int64_t other_contracting_dim_partitions,
2113 std::vector<int64_t>* sharding_dims_adjusted_to_output)
2114 -> std::optional<PartitionedHlo> {
2115 if (operand.sharding().IsTileMaximal()) {
2116 auto partially_sharded = PerGroupSliceFromReplicated(
2117 operand.Replicate().hlo(), operand.state().partition_id,
2118 output_grouped.device_groups, batch_dims,
2119 output_grouped.group_dim_sizes, b);
2120 partially_sharded->set_sharding(HloSharding::Replicate());
2121 return PartitionedHlo(partially_sharded, partially_sharded->shape(),
2122 per_group_partitioner_state);
2123 }
2124 auto reshaped_tiling = operand.sharding().tile_assignment();
2125 // It's possible that the operand, although tiled, is not initially sharded
2126 // on the batch dimensions in the same way as the output. In that case, the
2127 // current sharding_dims_adjusted_to_output may contain more partitions than
2128 // there are available devices, so we remove the partitioning on other
2129 // dimensions.
2130 if (Product(*sharding_dims_adjusted_to_output) >
2131 reshaped_tiling.num_elements()) {
2132 if (Product(*sharding_dims_adjusted_to_output) %
2133 reshaped_tiling.num_elements() !=
2134 0) {
2135 return std::nullopt;
2136 }
2137 int64_t ratio = Product(*sharding_dims_adjusted_to_output) /
2138 reshaped_tiling.num_elements();
2139 if (operand.sharding().ReplicateOnLastTileDim() &&
2140 reshaped_tiling.dimensions().back() % ratio == 0) {
2141 sharding_dims_adjusted_to_output->back() /= ratio;
2142 if (sharding_dims_adjusted_to_output->back() == 1) {
2143 sharding_dims_adjusted_to_output->pop_back();
2144 }
2145 } else if (ratio == non_contracting_dim_partitions &&
2146 (ratio != contracting_dim_partitions ||
2147 contracting_dim_partitions ==
2148 other_contracting_dim_partitions)) {
2149 for (int64_t dim : non_contracting_dims) {
2150 (*sharding_dims_adjusted_to_output)[dim] = 1;
2151 }
2152 } else if (ratio == contracting_dim_partitions) {
2153 for (int64_t dim : contracting_dims) {
2154 (*sharding_dims_adjusted_to_output)[dim] = 1;
2155 }
2156 } else {
2157 return std::nullopt;
2158 }
2159 }
2160 // If the operand is initially sharded more ways than the output in the
2161 // batch dimensions, sharding_dims_adjusted_to_output currently contains
2162 // fewer partitions than available devices. We do not handle this case.
2163 if (Product(*sharding_dims_adjusted_to_output) <
2164 reshaped_tiling.num_elements()) {
2165 return std::nullopt;
2166 }
2167 reshaped_tiling.Reshape(*sharding_dims_adjusted_to_output);
2168 auto grouped =
2169 AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims(
2170 operand.base_shape().rank() <
2171 sharding_dims_adjusted_to_output->size()
2172 ? HloSharding::PartialTile(reshaped_tiling)
2173 : HloSharding::Tile(reshaped_tiling),
2174 batch_dims),
2175 output_grouped);
2176 if (require_matching_devices_to_group &&
2177 operand.sharding() != UngroupSharding(grouped)) {
2178 return std::nullopt;
2179 }
2180 auto resharded = operand.Reshard(UngroupSharding(grouped));
2181 top_level_sharding_to_reset.emplace_back(resharded.hlo(),
2182 resharded.sharding());
2183 resharded.hlo()->set_sharding(grouped.sharding);
2184 return PartitionedHlo(resharded.hlo(),
2185 GetPerGroupBaseShape(grouped, operand.base_shape()),
2186 per_group_partitioner_state);
2187 };
2188 std::vector<int64_t> lhs_contracting_dims;
2189 std::vector<int64_t> rhs_contracting_dims;
2190 lhs_contracting_dims.reserve(dims_mapping.contracting_dims.size());
2191 rhs_contracting_dims.reserve(dims_mapping.contracting_dims.size());
2192 for (const auto& dim : dims_mapping.contracting_dims) {
2193 lhs_contracting_dims.push_back(dim.lhs);
2194 rhs_contracting_dims.push_back(dim.rhs);
2195 }
2196 std::vector<int64_t> lhs_non_contracting_dims;
2197 std::vector<int64_t> rhs_non_contracting_dims;
2198 lhs_non_contracting_dims.reserve(
2199 dims_mapping.lhs_non_contracting_dims.size());
2200 rhs_non_contracting_dims.reserve(
2201 dims_mapping.rhs_non_contracting_dims.size());
2202 for (const auto& dim : dims_mapping.lhs_non_contracting_dims) {
2203 lhs_non_contracting_dims.push_back(dim.lhs);
2204 }
2205 for (const auto& dim : dims_mapping.rhs_non_contracting_dims) {
2206 rhs_non_contracting_dims.push_back(dim.rhs);
2207 }
2208 if (auto resharded = reshard_to_output_batch(
2209 lhs, lhs_dims, lhs_contracting_dims, lhs_non_contracting_dims,
2210 lhs_contracting_partitions, lhs_non_contracting_partitions,
2211 rhs_contracting_partitions,
2212 &lhs_sharding_dims_adjusted_to_output)) {
2213 per_group_lhs = *resharded;
2214 } else {
2215 return nullptr;
2216 }
2217 if (auto resharded = reshard_to_output_batch(
2218 rhs, rhs_dims, rhs_contracting_dims, rhs_non_contracting_dims,
2219 rhs_contracting_partitions, rhs_non_contracting_partitions,
2220 lhs_contracting_partitions,
2221 &rhs_sharding_dims_adjusted_to_output)) {
2222 per_group_rhs = *resharded;
2223 } else {
2224 return nullptr;
2225 }
2226 CHECK(lhs.hlo() != rhs.hlo() ||
2227 per_group_lhs.sharding() == per_group_rhs.sharding());
2228 }
2229 TF_ASSIGN_OR_RETURN(
2230 auto dot,
2231 PartitionDot(per_group_lhs, per_group_rhs,
2232 GetPerGroupBaseShape(output_grouped, output_base_shape),
2233 output_grouped.sharding, dims_mapping,
2234 num_partitions / output_grouped.device_groups.size(),
2235 create_sharded_dot, conv_window, module, original_hlo,
2236 options, b, windowed_dot_general_loops, visitor));
2237 dot->set_sharding(UngroupSharding(output_grouped));
2238 return PartitionedHlo(dot, output_base_shape, lhs.state())
2239 .Reshard(output_sharding)
2240 .hlo();
2241 }
2242
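// Returns the grouped sharding for the "matching" operand (the one whose
// non-contracting dimensions are partitioned like the output), with its
// groups aligned to the output's device groups.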
2243 GroupedSharding GetNonContractingPartitionGroupedShardingForMatchedOperand(
2244 bool lhs_matching, const HloSharding& matching_sharding,
2245 const HloSharding& output_sharding,
2246 absl::Span<const DotConvDimsMapping::DimsMapping> partitioned_dims) {
2247 std::vector<int64_t> matching_sharding_dims =
2248 matching_sharding.tile_assignment().dimensions();
2249 std::vector<int64_t> matching_dims;
2250 std::vector<int64_t> output_dims;
2251 // Make sure the partitioning on the matching operand's non-contracting
2252 // dimensions defines the same device groups for both it and the output.
2253 for (const auto& dim : partitioned_dims) {
2254 int64_t md = lhs_matching ? dim.lhs : dim.rhs;
2255 matching_sharding_dims[md] =
2256 output_sharding.tile_assignment().dim(dim.output);
2257 matching_dims.push_back(md);
2258 output_dims.push_back(dim.output);
2259 }
2260 GroupedSharding output_grouped =
2261 hlo_sharding_util::GroupShardingOnDims(output_sharding, output_dims);
2262 Array<int64_t> reshaped_matching_tiling = matching_sharding.tile_assignment();
2263 reshaped_matching_tiling.Reshape(matching_sharding_dims);
2264 return AlignGroupsWith(
2265 hlo_sharding_util::GroupShardingOnDims(
2266 matching_sharding.ReplicateOnLastTileDim()
2267 ? HloSharding::PartialTile(reshaped_matching_tiling)
2268 : HloSharding::Tile(reshaped_matching_tiling),
2269 matching_dims),
2270 output_grouped);
2271 }
2272
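// Computes a grouped sharding for the non-matching ("other") operand that is
// compatible with the output's device groups, grouping on its
// partial-replication dimension, its contracting dimensions, or its
// non-contracting dimensions; returns std::nullopt if no such grouping is
// found (the caller then replicates the operand).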
2273 std::optional<GroupedSharding>
2274 GetNonContractingPartitionGroupedShardingForOtherOperand(
2275 bool lhs_matching, const Shape& output_base_shape, const Shape& other_shape,
2276 int64_t other_contracting_partitions,
2277 int64_t other_non_contracting_partitions,
2278 int64_t matching_contracting_partitions,
2279 int64_t output_other_non_contracting_partitions,
2280 const HloSharding& other_sharding, const HloSharding& output_sharding,
2281 absl::Span<const DotConvDimsMapping::DimsMapping> matching_partitioned_dims,
2282 absl::Span<const DotConvDimsMapping::DimsMapping>
2283 other_non_contracting_dims,
2284 absl::Span<const DotConvDimsMapping::DimsMapping> other_contracting_dims) {
2285 int64_t group_count = 1;
2286 std::vector<int64_t> output_dims;
2287 output_dims.reserve(matching_partitioned_dims.size());
2288 for (const auto& dim : matching_partitioned_dims) {
2289 output_dims.push_back(dim.output);
2290 group_count *= output_sharding.tile_assignment().dim(dim.output);
2291 }
2292 GroupedSharding output_grouped =
2293 hlo_sharding_util::GroupShardingOnDims(output_sharding, output_dims);
2294 std::vector<int64_t> other_group_dims;
2295 if (other_sharding.ReplicateOnLastTileDim() &&
2296 other_sharding.tile_assignment().dimensions().back() % group_count == 0) {
2297 other_group_dims.push_back(
2298 other_sharding.tile_assignment().num_dimensions() - 1);
2299 } else {
2300 const bool may_replicate_other_contracting_dims =
2301 (other_contracting_partitions == group_count &&
2302 other_non_contracting_partitions ==
2303 output_other_non_contracting_partitions);
2304 const bool may_replicate_other_non_contracting_dims =
2305 group_count == other_non_contracting_partitions &&
2306 matching_contracting_partitions == other_contracting_partitions;
2307 if (auto found_dims = FindMatchingPartitionedDimsForGrouping(
2308 other_sharding, output_grouped.device_groups)) {
2309 other_group_dims = std::move(*found_dims);
2310 } else if (may_replicate_other_contracting_dims &&
2311 (!may_replicate_other_non_contracting_dims ||
2312 ShapeUtil::ByteSizeOf(other_shape) <=
2313 ShapeUtil::ByteSizeOf(MakePartitionedShape(
2314 output_base_shape, output_sharding)))) {
2315 for (const auto& dim : other_contracting_dims) {
2316 other_group_dims.push_back(lhs_matching ? dim.rhs : dim.lhs);
2317 }
2318 } else if (may_replicate_other_non_contracting_dims) {
2319 for (const auto& dim : other_non_contracting_dims) {
2320 other_group_dims.push_back(lhs_matching ? dim.rhs : dim.lhs);
2321 }
2322 } else {
2323 return std::nullopt;
2324 }
2325 }
2326 if (other_group_dims.size() == 1 &&
2327 other_group_dims[0] ==
2328 other_sharding.tile_assignment().num_dimensions() - 1) {
2329 std::vector<int64_t> group_dim_shards = {
2330 other_sharding.tile_assignment().dimensions().back() / group_count};
2331 return AlignGroupsWith(
2332 hlo_sharding_util::GroupShardingOnDims(
2333 other_sharding, {other_group_dims[0]}, group_dim_shards),
2334 output_grouped, /*ignore_group_order=*/true);
2335
2336 } else if (!other_sharding.IsReplicated()) {
2337 return AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims(
2338 other_sharding, other_group_dims),
2339 output_grouped,
2340 /*ignore_group_order=*/true);
2341 }
2342 return std::nullopt;
2343 }
2344
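// Partitions a dot by grouping devices along non-contracting dimensions of
// one operand (the "matching" operand) that are partitioned the same way as
// the corresponding output dimensions, then recursively partitions the dot
// within each group.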
2345 StatusOr<HloInstruction*> PartitionDotGroupOnNonContracting(
2346 bool lhs_matching, PartitionedHlo matching, PartitionedHlo other,
2347 int64_t matching_contracting_partitions,
2348 int64_t other_contracting_partitions,
2349 absl::Span<const DotConvDimsMapping::DimsMapping>
2350 partitioned_non_contracting_dims,
2351 int64_t other_non_contracting_partitions,
2352 int64_t output_other_non_contracting_partitions,
2353 const Shape& output_base_shape, const HloSharding& output_sharding,
2354 const DotConvDimsMapping& dims_mapping, int64_t num_partitions,
2355 const std::function<StatusOr<HloInstruction*>(
2356 HloInstruction*, HloInstruction*, SpmdBuilder*,
2357 const Window& conv_window)>& create_sharded_dot,
2358 const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
2359 bool require_matching_devices_to_group,
2360 const SpmdPartitionerOptions& options, SpmdBuilder* b,
2361 std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
2362 windowed_dot_general_loops,
2363 SpmdPartitioningVisitor* visitor) {
2364 std::vector<std::pair<HloInstruction*, HloSharding>>
2365 top_level_sharding_to_reset;
2366 absl::Cleanup cleaner = [&] {
2367 for (auto& to_reset : top_level_sharding_to_reset) {
2368 to_reset.first->set_sharding(to_reset.second);
2369 }
2370 };
2371
2372 std::vector<int64_t> output_dims;
2373 output_dims.reserve(partitioned_non_contracting_dims.size());
2374 for (const auto& dim : partitioned_non_contracting_dims) {
2375 output_dims.push_back(dim.output);
2376 }
2377 GroupedSharding output_grouped =
2378 hlo_sharding_util::GroupShardingOnDims(output_sharding, output_dims);
2379 GroupedSharding matching_grouped =
2380 GetNonContractingPartitionGroupedShardingForMatchedOperand(
2381 lhs_matching, matching.sharding(), output_sharding,
2382 partitioned_non_contracting_dims);
2383 if (require_matching_devices_to_group &&
2384 matching.sharding() != UngroupSharding(matching_grouped)) {
2385 return nullptr;
2386 }
2387 std::optional<GroupedSharding> other_grouped =
2388 GetNonContractingPartitionGroupedShardingForOtherOperand(
2389 lhs_matching, output_base_shape, other.hlo()->shape(),
2390 other_contracting_partitions, other_non_contracting_partitions,
2391 matching_contracting_partitions,
2392 output_other_non_contracting_partitions, other.sharding(),
2393 output_sharding, partitioned_non_contracting_dims,
2394 lhs_matching ? dims_mapping.rhs_non_contracting_dims
2395 : dims_mapping.lhs_non_contracting_dims,
2396 dims_mapping.contracting_dims);
2397
2398 if (!other_grouped) {
2399 other = other.Replicate();
2400 }
2401 matching = matching.Reshard(UngroupSharding(matching_grouped));
2402 auto per_group_partitioner_state = CreatePerGroupPartitioningState(
2403 matching.state(), matching_grouped.device_groups, b);
2404 top_level_sharding_to_reset.emplace_back(matching.hlo(), matching.sharding());
2405 matching.hlo()->set_sharding(matching_grouped.sharding);
2406 auto matching_p = PartitionedHlo(
2407 matching.hlo(),
2408 GetPerGroupBaseShape(matching_grouped, matching.base_shape()),
2409 per_group_partitioner_state);
2410
2411 auto partially_replicated_other = other.hlo();
2412 if (other_grouped && other_grouped->group_dims.size() == 1 &&
2413 other_grouped->group_dims[0] == other.base_shape().rank()) {
2414 // Group on replication dim.
2415 other = other.Reshard(UngroupSharding(*other_grouped));
2416 partially_replicated_other = other.hlo();
2417 top_level_sharding_to_reset.emplace_back(other.hlo(), other.sharding());
2418 partially_replicated_other->set_sharding(other_grouped->sharding);
2419 } else if (!other.sharding().IsReplicated()) {
2420 HloSharding target_sharding = UngroupSharding(*other_grouped);
2421 GroupedSharding target_group_sharding =
2422 hlo_sharding_util::GroupShardingOnDims(target_sharding,
2423 other_grouped->group_dims);
2424 const bool device_group_match = hlo_sharding_util::DeviceGroupsAreMatch(
2425 target_group_sharding, *other_grouped, /*ignore_group_order=*/false);
2426
2427 // Do not reshard for partial replication if the device groups match.
2428 // There is a reshard to partial replication right after this reshard; if
2429 // the device ids within each partial-replication group are the same, there
2430 // is no need to reshard here.
2431 if (!other.sharding().ReplicateOnLastTileDim() || !device_group_match) {
2432 other = other.Reshard(target_sharding);
2433 }
2434 partially_replicated_other =
2435 other
2436 .Reshard(hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
2437 other.sharding(), other_grouped->group_dims))
2438 .hlo();
2439 top_level_sharding_to_reset.emplace_back(
2440 partially_replicated_other, partially_replicated_other->sharding());
2441 partially_replicated_other->set_sharding(other_grouped->sharding);
2442 }
2443 auto other_p = PartitionedHlo(partially_replicated_other, other.base_shape(),
2444 per_group_partitioner_state);
2445 TF_ASSIGN_OR_RETURN(
2446 auto dot,
2447 PartitionDot(lhs_matching ? matching_p : other_p,
2448 lhs_matching ? other_p : matching_p,
2449 GetPerGroupBaseShape(output_grouped, output_base_shape),
2450 output_grouped.sharding, dims_mapping,
2451 num_partitions / matching_grouped.device_groups.size(),
2452 create_sharded_dot, conv_window, module, original_hlo,
2453 options, b, windowed_dot_general_loops, visitor));
2454 return dot;
2455 }
2456
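// Picks the inner (per-group) output sharding and a temporary outer output
// sharding to use when the dot is grouped on contracting dimensions,
// preferring output dimensions whose partitioning already matches the LHS
// device groups.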
2457 std::pair<HloSharding, HloSharding>
2458 GetDotGroupPartitionContractingOutputShardings(
2459 const DotConvDimsMapping& dims_mapping, const GroupedSharding& lhs_grouped,
2460 const Shape& output_base_shape, const HloSharding& output_sharding,
2461 int64_t group_count, int64_t output_lhs_non_contracting_partitions,
2462 int64_t output_rhs_non_contracting_partitions,
2463 int64_t output_batch_partitions,
2464 std::vector<int64_t>* output_slice_dims_out,
2465 bool* output_replicate_dim_grouped = nullptr) {
2466 HloSharding inner_output_sharding = HloSharding::Replicate();
2467 HloSharding outer_output_tmp_sharding = HloSharding::Replicate();
2468 std::vector<int64_t> output_slice_dims;
2469 if (output_sharding.ReplicateOnLastTileDim() &&
2470 output_sharding.tile_assignment().dimensions().back() % group_count ==
2471 0) {
2472 std::vector<int64_t> group_dim_shards = {
2473 output_sharding.tile_assignment().dimensions().back() / group_count};
2474 auto grouped = AlignGroupsWith(
2475 hlo_sharding_util::GroupShardingOnDims(
2476 output_sharding,
2477 {output_sharding.tile_assignment().num_dimensions() - 1},
2478 group_dim_shards),
2479 lhs_grouped,
2480 /*ignore_group_order=*/true);
2481 outer_output_tmp_sharding = UngroupSharding(grouped);
2482 inner_output_sharding = std::move(grouped.sharding);
2483 } else {
2484 if (auto found_dims = FindMatchingPartitionedDimsForGrouping(
2485 output_sharding, lhs_grouped.device_groups)) {
2486 output_slice_dims = std::move(*found_dims);
2487 if (!output_slice_dims.empty()) {
2488 // FindMatchingPartitionedDimsForGrouping already makes sure the groups
2489 // are compatible with LHS/RHS. We avoid AlignGroupsWith/UngroupSharding
2490 // because that could change the group order and cause a reshard with
2491 // collective-permute, which is unnecessary since these groups will be
2492 // all-reduced over anyway for contracting-dim sharding.
2493 auto grouped = hlo_sharding_util::GroupShardingOnDims(
2494 output_sharding, output_slice_dims);
2495 inner_output_sharding = grouped.sharding;
2496 outer_output_tmp_sharding = output_sharding;
2497 }
2498 } else if (output_lhs_non_contracting_partitions == group_count ||
2499 output_rhs_non_contracting_partitions == group_count ||
2500 output_batch_partitions == group_count) {
2501 if (output_lhs_non_contracting_partitions == group_count) {
2502 for (const auto& dim : dims_mapping.lhs_non_contracting_dims) {
2503 output_slice_dims.push_back(dim.output);
2504 }
2505 } else if (output_rhs_non_contracting_partitions == group_count) {
2506 for (const auto& dim : dims_mapping.rhs_non_contracting_dims) {
2507 output_slice_dims.push_back(dim.output);
2508 }
2509 } else {
2510 for (const auto& dim : dims_mapping.batch_dims) {
2511 output_slice_dims.push_back(dim.output);
2512 }
2513 }
2514 if (!output_slice_dims.empty()) {
2515 auto grouped = AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims(
2516 output_sharding, output_slice_dims),
2517 lhs_grouped);
2518 inner_output_sharding = grouped.sharding;
2519 outer_output_tmp_sharding = UngroupSharding(grouped);
2520 }
2521 }
2522 }
2523 if (output_replicate_dim_grouped) {
2524 *output_replicate_dim_grouped =
2525 absl::c_linear_search(output_slice_dims, output_base_shape.rank());
2526 }
2527 if (output_slice_dims_out) {
2528 if (output_sharding.ReplicateOnLastTileDim()) {
2529 // Remove the replication group dim.
2530 output_slice_dims.erase(
2531 std::remove_if(
2532 output_slice_dims.begin(), output_slice_dims.end(),
2533 [&](int64_t dim) { return dim == output_base_shape.rank(); }),
2534 output_slice_dims.end());
2535 }
2536 (*output_slice_dims_out) = std::move(output_slice_dims);
2537 }
2538 return std::make_pair(inner_output_sharding, outer_output_tmp_sharding);
2539 }
2540
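// Returns adjusted (lhs, rhs) shardings in which the smaller operand's
// sharding on the partitioned contracting dimensions is rewritten to match
// the larger operand's.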
2541 std::pair<HloSharding, HloSharding>
2542 GetDotGroupPartitionContractingLhsRhsShardings(
2543 const PartitionedHlo& lhs, const PartitionedHlo& rhs,
2544 absl::Span<const DotConvDimsMapping::DimsMapping>
2545 partitioned_contracting_dims) {
2546 HloSharding lhs_sharding = lhs.sharding();
2547 HloSharding rhs_sharding = rhs.sharding();
2548 std::vector<int64_t> lhs_tile_shape =
2549 lhs_sharding.tile_assignment().dimensions();
2550 std::vector<int64_t> rhs_tile_shape =
2551 rhs_sharding.tile_assignment().dimensions();
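// Make the contracting-dim tiling of the smaller operand match the larger
// operand's, so that the larger operand keeps its sharding and only the
// smaller one needs to be resharded later.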
2552 if (ShapeUtil::ByteSizeOf(lhs.hlo()->shape()) >
2553 ShapeUtil::ByteSizeOf(rhs.hlo()->shape())) {
2554 for (const auto& dim : partitioned_contracting_dims) {
2555 rhs_tile_shape[dim.rhs] = lhs_tile_shape[dim.lhs];
2556 }
2557 auto new_tile = rhs.sharding().tile_assignment();
2558 new_tile.Reshape(rhs_tile_shape);
2559 rhs_sharding = rhs_sharding.ReplicateOnLastTileDim()
2560 ? HloSharding::PartialTile(new_tile)
2561 : HloSharding::Tile(new_tile);
2562 } else {
2563 for (const auto& dim : partitioned_contracting_dims) {
2564 lhs_tile_shape[dim.lhs] = rhs_tile_shape[dim.rhs];
2565 }
2566 auto new_tile = lhs.sharding().tile_assignment();
2567 new_tile.Reshape(lhs_tile_shape);
2568 lhs_sharding = lhs_sharding.ReplicateOnLastTileDim()
2569 ? HloSharding::PartialTile(new_tile)
2570 : HloSharding::Tile(new_tile);
2571 }
2572 return std::make_pair(lhs_sharding, rhs_sharding);
2573 }
2574
2575 StatusOr<HloInstruction*> PartitionDotGroupOnContracting(
2576 PartitionedHlo lhs, PartitionedHlo rhs,
2577 absl::Span<const DotConvDimsMapping::DimsMapping>
2578 partitioned_contracting_dims,
2579 int64_t output_batch_partitions,
2580 int64_t output_lhs_non_contracting_partitions,
2581 int64_t output_rhs_non_contracting_partitions,
2582 const Shape& output_base_shape, const HloSharding& output_sharding,
2583 const DotConvDimsMapping& dims_mapping, int64_t num_partitions,
2584 const std::function<StatusOr<HloInstruction*>(
2585 HloInstruction*, HloInstruction*, SpmdBuilder*,
2586 const Window& conv_window)>& create_sharded_dot,
2587 const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
2588 bool require_matching_devices_to_group,
2589 const SpmdPartitionerOptions& options, SpmdBuilder* b,
2590 std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
2591 windowed_dot_general_loops,
2592 SpmdPartitioningVisitor* visitor) {
2593 std::vector<std::pair<HloInstruction*, HloSharding>>
2594 top_level_sharding_to_reset;
2595 absl::Cleanup cleaner = [&] {
2596 for (auto& to_reset : top_level_sharding_to_reset) {
2597 to_reset.first->set_sharding(to_reset.second);
2598 }
2599 };
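// Collect the contracting dimensions being grouped on and the total group
// count (the product of the partition counts along those dimensions).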
2600 std::vector<int64_t> lhs_dims;
2601 std::vector<int64_t> rhs_dims;
2602 int64_t group_count = 1;
2603 for (const auto& dim : partitioned_contracting_dims) {
2604 lhs_dims.push_back(dim.lhs);
2605 rhs_dims.push_back(dim.rhs);
2606 group_count *= lhs.sharding().tile_assignment().dim(dim.lhs);
2607 }
2608 HloSharding lhs_sharding = HloSharding::Replicate();
2609 HloSharding rhs_sharding = HloSharding::Replicate();
2610 std::tie(lhs_sharding, rhs_sharding) =
2611 GetDotGroupPartitionContractingLhsRhsShardings(
2612 lhs, rhs, partitioned_contracting_dims);
2613 auto lhs_grouped =
2614 hlo_sharding_util::GroupShardingOnDims(lhs_sharding, lhs_dims);
2615 auto rhs_grouped =
2616 hlo_sharding_util::GroupShardingOnDims(rhs_sharding, rhs_dims);
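// Align the device groups of the smaller operand with those of the larger
// operand and reshard it if needed; bail out if matching device groups are
// required but the shardings differ.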
2617 if (ShapeUtil::ByteSizeOf(lhs.hlo()->shape()) >
2618 ShapeUtil::ByteSizeOf(rhs.hlo()->shape())) {
2619 rhs_grouped = AlignGroupsWith(rhs_grouped, lhs_grouped);
2620 rhs_sharding = UngroupSharding(rhs_grouped);
2621 if (require_matching_devices_to_group && rhs.sharding() != rhs_sharding) {
2622 return nullptr;
2623 }
2624 rhs = rhs.Reshard(rhs_sharding);
2625 } else {
2626 lhs_grouped = AlignGroupsWith(lhs_grouped, rhs_grouped);
2627 lhs_sharding = UngroupSharding(lhs_grouped);
2628 if (require_matching_devices_to_group && lhs.sharding() != lhs_sharding) {
2629 return nullptr;
2630 }
2631 lhs = lhs.Reshard(lhs_sharding);
2632 }
2633 // Mask out invalid data.
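// Only the grouped contracting dimensions need zero-masking; all other
// dimensions are passed to PadWithZero as skipped dims.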
2634 std::vector<int64_t> lhs_skipped_dims;
2635 for (int64_t i = 0; i < lhs.base_shape().rank(); ++i) {
2636 if (absl::c_linear_search(lhs_dims, i)) {
2637 continue;
2638 }
2639 lhs_skipped_dims.push_back(i);
2640 }
2641 lhs = lhs.PadWithZero(
2642 /*left_padded_dims=*/{}, lhs_skipped_dims);
2643 std::vector<int64_t> rhs_skipped_dims;
2644 for (int64_t i = 0; i < rhs.base_shape().rank(); ++i) {
2645 if (absl::c_linear_search(rhs_dims, i)) {
2646 continue;
2647 }
2648 rhs_skipped_dims.push_back(i);
2649 }
2650 rhs = rhs.PadWithZero(
2651 /*left_padded_dims=*/{}, rhs_skipped_dims);
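// Temporarily assign the per-group (inner) shardings to the operand HLOs;
// the absl::Cleanup above restores the original shardings on return.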
2652 top_level_sharding_to_reset.emplace_back(lhs.hlo(), lhs_sharding);
2653 lhs.hlo()->set_sharding(lhs_grouped.sharding);
2654 top_level_sharding_to_reset.emplace_back(rhs.hlo(), rhs_sharding);
2655 rhs.hlo()->set_sharding(rhs_grouped.sharding);
2656
2657 HloSharding inner_output_sharding = HloSharding::Replicate();
2658 HloSharding outer_output_tmp_sharding = HloSharding::Replicate();
2659 std::vector<int64_t> output_slice_dims;
2660 bool output_replicate_dim_grouped;
2661 std::tie(inner_output_sharding, outer_output_tmp_sharding) =
2662 GetDotGroupPartitionContractingOutputShardings(
2663 dims_mapping, lhs_grouped, output_base_shape, output_sharding,
2664 group_count, output_lhs_non_contracting_partitions,
2665 output_rhs_non_contracting_partitions, output_batch_partitions,
2666 &output_slice_dims, &output_replicate_dim_grouped);
2667 Shape inner_output_base_shape = output_base_shape;
2668 auto get_non_slice_dims = [&] {
2669 std::vector<int64_t> non_group_dims;
2670 for (int64_t i = 0; i < output_base_shape.rank(); ++i) {
2671 if (!absl::c_linear_search(output_slice_dims, i)) {
2672 non_group_dims.push_back(i);
2673 }
2674 }
2675 return non_group_dims;
2676 };
2677 if (!output_slice_dims.empty()) {
2678 inner_output_base_shape = MakePartitionedShape(
2679 output_base_shape,
2680 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
2681 output_sharding, get_non_slice_dims()));
2682 }
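// inner_creator wraps create_sharded_dot: after the per-group dot it combines
// the partial results across the contracting-dim groups, either by
// all-reducing and resharding to the sliced output shape (enabling a
// reduce-scatter pattern) or by a plain all-reduce.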
2683 std::function<StatusOr<HloInstruction*>(HloInstruction*, HloInstruction*,
2684 SpmdBuilder*, const Window&)>
2685 inner_creator =
2686 [&](HloInstruction* l, HloInstruction* r, SpmdBuilder* b,
2687 const Window& conv_window) -> StatusOr<HloInstruction*> {
2688 TF_ASSIGN_OR_RETURN(auto inner_dot,
2689 create_sharded_dot(l, r, b, conv_window));
2690 HloInstruction* result = inner_dot;
2691 if (!output_slice_dims.empty()) {
2692 // Create an AllReduce along slice dims first to allow a reduce-scatter.
2693 result = lhs.state().partitioner->AllReduceAlongShardingDims(
2694 b, result, outer_output_tmp_sharding, lhs.state().next_channel_id,
2695 output_slice_dims, lhs.state().collective_ops_creator,
2696 MakeBinaryAdd(output_base_shape.element_type(), module));
2697 // Use resharding to slice the output. Use a temporary reshard cache since
2698 // we temporarily fake a replicated sharding on the result.
2699 PartitionedHlo::PartitioningState new_state = lhs.state();
2700 new_state.b = b;
2701 new_state.partition_id =
2702 lhs.state().collective_ops_creator.create_partition_id(b);
2703 PartitionedHlo::ReshardCache tmp_cache;
2704 new_state.reshard_cache = &tmp_cache;
2705 result->set_sharding(HloSharding::Replicate());
2706 result =
2707 PartitionedHlo(result, result->shape(), new_state)
2708 .Reshard(hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
2709 outer_output_tmp_sharding, get_non_slice_dims()))
2710 .hlo();
2711 // If the output has partial replication, and the last tile dim is used
2712 // for grouping, we need to do a separate allreduce after reduce-scatter.
2713 if (output_replicate_dim_grouped) {
2714 result = lhs.state().partitioner->AllReduceAlongShardingDims(
2715 b, result, outer_output_tmp_sharding, lhs.state().next_channel_id,
2716 {output_base_shape.rank()}, lhs.state().collective_ops_creator,
2717 MakeBinaryAdd(output_base_shape.element_type(), module));
2718 }
2719 } else {
2720 result = lhs.state().partitioner->AllReduceAlongShardingDims(
2721 b, result, lhs_sharding, lhs.state().next_channel_id, lhs_dims,
2722 lhs.state().collective_ops_creator,
2723 MakeBinaryAdd(output_base_shape.element_type(), module));
2724 }
2725 return result;
2726 };
2727
2728 PartitionedHlo::PartitioningState inner_state =
2729 CreatePerGroupPartitioningState(lhs.state(), lhs_grouped.device_groups,
2730 b);
2731
2732 HloInstruction* maybe_windowed_dot = nullptr;
2733
2734 // Tentatively disables the inner reshard when the "faster windowed einsum"
2735 // flag is enabled, because the windowed einsum implementation is currently
2736 // slow when this kind of reshard happens.
2737 int original_num_windowed_loops = windowed_dot_general_loops->size();
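// Run a trial PartitionDot with the unwrapped creator and the full output
// base shape; comparing the windowed-loop count before and after tells
// whether a windowed einsum would be generated for this grouping.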
2738 if (options.choose_faster_windowed_einsum_over_mem) {
2739 Shape predicted_inner_output_base_shape = output_base_shape;
2740 auto predicted_inner_creator = create_sharded_dot;
2741 TF_ASSIGN_OR_RETURN(
2742 maybe_windowed_dot,
2743 PartitionDot(
2744 PartitionedHlo(lhs.hlo(),
2745 GetPerGroupBaseShape(lhs_grouped, lhs.base_shape()),
2746 inner_state),
2747 PartitionedHlo(rhs.hlo(),
2748 GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()),
2749 inner_state),
2750 predicted_inner_output_base_shape, inner_output_sharding,
2751 dims_mapping, num_partitions / group_count, predicted_inner_creator,
2752 conv_window, module, original_hlo, options, b,
2753 windowed_dot_general_loops, visitor));
2754 }
2755 int new_num_windowed_loops = windowed_dot_general_loops->size();
2756
2757 TF_ASSIGN_OR_RETURN(
2758 auto inner_dot,
2759 PartitionDot(
2760 PartitionedHlo(lhs.hlo(),
2761 GetPerGroupBaseShape(lhs_grouped, lhs.base_shape()),
2762 inner_state),
2763 PartitionedHlo(rhs.hlo(),
2764 GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()),
2765 inner_state),
2766 inner_output_base_shape, inner_output_sharding, dims_mapping,
2767 num_partitions / group_count, inner_creator, conv_window, module,
2768 original_hlo, options, b, windowed_dot_general_loops, visitor));
2769
2770 // Re-enables the inner reshard if the inner dot succeeded and no
2771 // windowed_dot_general_loops were actually generated.
2772 if (inner_dot && (new_num_windowed_loops == original_num_windowed_loops)) {
2773 maybe_windowed_dot = inner_dot;
2774 } else if (maybe_windowed_dot) {
2775 if (options.choose_faster_windowed_einsum_over_mem) {
2776 HloInstruction* ar = lhs.state().partitioner->AllReduceAlongShardingDims(
2777 b, maybe_windowed_dot, lhs_sharding, lhs.state().next_channel_id,
2778 lhs_dims, lhs.state().collective_ops_creator,
2779 MakeBinaryAdd(output_base_shape.element_type(), module));
2780 maybe_windowed_dot = ar;
2781 outer_output_tmp_sharding =
2782 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
2783 outer_output_tmp_sharding, output_slice_dims);
2784 }
2785 } else {
2786 return nullptr;
2787 }
2788
2789 maybe_windowed_dot->set_sharding(outer_output_tmp_sharding);
2790 auto d = PartitionedHlo(maybe_windowed_dot, output_base_shape, lhs.state())
2791 .Reshard(output_sharding)
2792 .hlo();
2793 return d;
2794 }
2795
2796 DotConvDimsMapping ConvertDimsMappingWithFeatureGroupCount(
2797 const DotConvDimsMapping& dims_mapping, HloInstruction* original_hlo) {
2798 const auto& dnums = original_hlo->convolution_dimension_numbers();
2799 DotConvDimsMapping new_dims_mapping;
2800 new_dims_mapping.batch_dims = dims_mapping.batch_dims;
2801 new_dims_mapping.conv_spatial_dims = dims_mapping.conv_spatial_dims;
2802 // Append batch dims.
2803 new_dims_mapping.batch_dims.emplace_back();
2804 new_dims_mapping.batch_dims.back().lhs = dnums.input_feature_dimension();
2805 new_dims_mapping.batch_dims.back().rhs =
2806 dnums.kernel_output_feature_dimension();
2807 new_dims_mapping.batch_dims.back().output = dnums.output_feature_dimension();
2808 new_dims_mapping.batch_dims.back().spatial = -1;
2809 // Set up the non-contracting dims.
2810 new_dims_mapping.lhs_non_contracting_dims.emplace_back();
2811 new_dims_mapping.lhs_non_contracting_dims.back().lhs =
2812 dnums.input_batch_dimension();
2813 new_dims_mapping.rhs_non_contracting_dims.emplace_back();
2814 new_dims_mapping.rhs_non_contracting_dims.back().rhs =
2815 dnums.kernel_input_feature_dimension();
2816 return new_dims_mapping;
2817 }
2818
2819 DotConvDimsMapping ConvertDimsMappingWithBatchGroupCount(
2820 const DotConvDimsMapping& dims_mapping, HloInstruction* original_hlo) {
2821 const auto& dnums = original_hlo->convolution_dimension_numbers();
2822 DotConvDimsMapping new_dims_mapping;
2823 new_dims_mapping.batch_dims = dims_mapping.batch_dims;
2824 new_dims_mapping.conv_spatial_dims = dims_mapping.conv_spatial_dims;
2825 new_dims_mapping.contracting_dims = dims_mapping.contracting_dims;
2826 // Append batch dims.
2827 new_dims_mapping.batch_dims.emplace_back();
2828 new_dims_mapping.batch_dims.back().lhs = dnums.input_batch_dimension();
2829 new_dims_mapping.batch_dims.back().rhs =
2830 dnums.kernel_output_feature_dimension();
2831 new_dims_mapping.batch_dims.back().output = dnums.output_feature_dimension();
2832 new_dims_mapping.batch_dims.back().spatial = -1;
2833 return new_dims_mapping;
2834 }
2835
2836 // Estimates the number of iterations of a subsequent windowed einsum
2837 // partitioning if it is partitioned in the non-contracting dimensions.
2838 // The first value returned is the estimated number of iterations if the LHS
2839 // is matched, and the second is the estimate if the RHS is matched.
2840 std::pair<std::optional<int64_t>, std::optional<int64_t>>
2841 EstimateWindowedEinsumIterationsForNonContractingPartitioning(
2842 const DotConvDimsMapping& dims_mapping, const PartitionedHlo& lhs,
2843 const PartitionedHlo& rhs, const Shape& output_base_shape,
2844 const HloSharding& output_sharding, const SpmdPartitionerOptions& options,
2845 int64_t num_partitions, int64_t lhs_non_contracting_partitions,
2846 int64_t rhs_non_contracting_partitions, int64_t lhs_matching_partitions,
2847 int64_t rhs_matching_partitions, int64_t lhs_contracting_partitions,
2848 int64_t rhs_contracting_partitions,
2849 int64_t output_lhs_non_contracting_partitions,
2850 int64_t output_rhs_non_contracting_partitions, int64_t lhs_batch_partitions,
2851 int64_t rhs_batch_partitions, const Window& conv_window) {
2852 const DotDimensionIndexMapping indices_map = ComputeDimensionIndexMapping(
2853 dims_mapping, lhs.base_shape().rank(), rhs.base_shape().rank(),
2854 output_base_shape.rank());
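// For a hypothetical match on one operand's non-contracting dimensions, check
// whether a windowed einsum configuration exists for the remaining partitions
// and, if so, return the estimated iteration count (num_partitions divided by
// the matched partition count).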
2855 auto subsequent_einsum_iterations_estimate =
2856 [&](bool assume_lhs_match) -> std::optional<int64_t> {
2857 const std::vector<DotConvDimsMapping::DimsMapping>&
2858 matching_non_contracting_dims =
2859 assume_lhs_match ? dims_mapping.lhs_non_contracting_dims
2860 : dims_mapping.rhs_non_contracting_dims;
2861 const std::vector<DotConvDimsMapping::DimsMapping>&
2862 other_non_contracting_dims =
2863 assume_lhs_match ? dims_mapping.rhs_non_contracting_dims
2864 : dims_mapping.lhs_non_contracting_dims;
2865 const std::vector<int64_t>& output_to_matching_indices =
2866 assume_lhs_match ? indices_map.output_to_lhs_indices
2867 : indices_map.output_to_rhs_indices;
2868 const std::vector<int64_t>& output_to_other_indices =
2869 assume_lhs_match ? indices_map.output_to_rhs_indices
2870 : indices_map.output_to_lhs_indices;
2871 const std::vector<int64_t>& matching_to_output_indices =
2872 assume_lhs_match ? indices_map.lhs_to_output_indices
2873 : indices_map.rhs_to_output_indices;
2874 const std::vector<int64_t>& other_to_output_indices =
2875 assume_lhs_match ? indices_map.rhs_to_output_indices
2876 : indices_map.lhs_to_output_indices;
2877 const HloSharding& matching_sharding =
2878 assume_lhs_match ? lhs.sharding() : rhs.sharding();
2879 const HloSharding& other_sharding =
2880 assume_lhs_match ? rhs.sharding() : lhs.sharding();
2881 const PartitionedHlo& matching_partitioned = assume_lhs_match ? lhs : rhs;
2882 const PartitionedHlo& other_partitioned = assume_lhs_match ? rhs : lhs;
2883 const int64_t matching_non_contracting_partitions =
2884 assume_lhs_match ? lhs_non_contracting_partitions
2885 : rhs_non_contracting_partitions;
2886 const int64_t other_non_contracting_partitions =
2887 assume_lhs_match ? rhs_non_contracting_partitions
2888 : lhs_non_contracting_partitions;
2889 const int64_t matching_contracting_partitions =
2890 assume_lhs_match ? lhs_contracting_partitions
2891 : rhs_contracting_partitions;
2892 const int64_t other_contracting_partitions =
2893 assume_lhs_match ? rhs_contracting_partitions
2894 : lhs_contracting_partitions;
2895 const int64_t output_matching_non_contracting_partitions =
2896 assume_lhs_match ? output_lhs_non_contracting_partitions
2897 : output_rhs_non_contracting_partitions;
2898 const int64_t output_other_non_contracting_partitions =
2899 assume_lhs_match ? output_rhs_non_contracting_partitions
2900 : output_lhs_non_contracting_partitions;
2901 const int64_t matching_batch_partitions =
2902 assume_lhs_match ? lhs_batch_partitions : rhs_batch_partitions;
2903 const int64_t other_batch_partitions =
2904 assume_lhs_match ? rhs_batch_partitions : lhs_batch_partitions;
2905 const int64_t matching_matched_non_contracting_partitions =
2906 assume_lhs_match ? lhs_non_contracting_partitions
2907 : rhs_non_contracting_partitions;
2908 std::vector<int64_t> output_dims;
2909 output_dims.reserve(matching_non_contracting_dims.size());
2910 for (const DotConvDimsMapping::DimsMapping& dim :
2911 matching_non_contracting_dims) {
2912 output_dims.push_back(dim.output);
2913 }
2914 GroupedSharding output_grouped =
2915 hlo_sharding_util::GroupShardingOnDims(output_sharding, output_dims);
2916 GroupedSharding matching_grouped =
2917 GetNonContractingPartitionGroupedShardingForMatchedOperand(
2918 assume_lhs_match, matching_sharding, output_sharding,
2919 matching_non_contracting_dims);
2920 std::optional<GroupedSharding> other_grouped =
2921 GetNonContractingPartitionGroupedShardingForOtherOperand(
2922 assume_lhs_match, output_base_shape,
2923 other_partitioned.hlo()->shape(), other_contracting_partitions,
2924 other_non_contracting_partitions, matching_contracting_partitions,
2925 output_other_non_contracting_partitions, other_sharding,
2926 output_sharding, matching_non_contracting_dims,
2927 other_non_contracting_dims, dims_mapping.contracting_dims);
2928 if (!other_grouped) {
2929 return std::nullopt;
2930 }
2931 std::optional<HloSharding> output_sharding_transposed_to_match_matching =
2932 hlo_sharding_util::TransposeShardingWithCollapsedDims(
2933 output_grouped.sharding, output_to_matching_indices,
2934 matching_to_output_indices);
2935 std::optional<HloSharding> output_sharding_transposed_to_match_other =
2936 hlo_sharding_util::TransposeShardingWithCollapsedDims(
2937 output_grouped.sharding, output_to_other_indices,
2938 other_to_output_indices);
2939 auto lhs_sharding_transposed_to_match_rhs =
2940 hlo_sharding_util::TransposeShardingWithCollapsedDims(
2941 lhs.sharding(), indices_map.lhs_to_rhs_indices,
2942 indices_map.rhs_to_lhs_indices);
2943 auto rhs_sharding_transposed_to_match_lhs =
2944 hlo_sharding_util::TransposeShardingWithCollapsedDims(
2945 rhs.sharding(), indices_map.rhs_to_lhs_indices,
2946 indices_map.lhs_to_rhs_indices);
2947 const int64_t new_num_partitions =
2948 num_partitions / matching_non_contracting_partitions;
2949 std::optional<WindowedEinsumConfig> e_config =
2950 GetWindowedEinsumConfiguration(
2951 new_num_partitions, output_matching_non_contracting_partitions,
2952 output_other_non_contracting_partitions,
2953 other_contracting_partitions, other_non_contracting_partitions,
2954 other_batch_partitions, matching_contracting_partitions,
2955 matching_non_contracting_partitions /
2956 matching_matched_non_contracting_partitions,
2957 matching_batch_partitions,
2958 ShapeSizeInBytes(other_partitioned.base_shape()),
2959 ShapeSizeInBytes(matching_partitioned.base_shape()) /
2960 matching_non_contracting_partitions,
2961 ShapeSizeInBytes(
2962 GetPerGroupBaseShape(output_grouped, output_base_shape)),
2963 options, output_sharding_transposed_to_match_matching,
2964 output_sharding_transposed_to_match_other,
2965 lhs_sharding_transposed_to_match_rhs,
2966 rhs_sharding_transposed_to_match_lhs, matching_grouped.sharding,
2967 other_grouped->sharding, conv_window, dims_mapping);
2968 return e_config ? new_num_partitions : std::optional<int64_t>(std::nullopt);
2969 };
2970 std::optional<int64_t> lhs_matching_iterations;
2971 if (lhs_matching_partitions != 0) {
2972 lhs_matching_iterations = subsequent_einsum_iterations_estimate(true);
2973 }
2974 std::optional<int64_t> rhs_matching_iterations;
2975 if (rhs_matching_partitions != 0) {
2976 rhs_matching_iterations = subsequent_einsum_iterations_estimate(false);
2977 }
2978 return std::make_pair(lhs_matching_iterations, rhs_matching_iterations);
2979 }
2980
2981 // Returns true if we should prioritize partitioning on the contracting
2982 // dimensions over the non-contracting dimensions because we estimate that
2983 // would be faster. The general idea is similar to the one in
2984 // LhsIsBestMatchForNonContractingPartitioning, with one all-gather replaced
2985 // by a reduce-scatter.
2986 bool PrioritizeContractingDimensionsPartitioning(
2987 const DotConvDimsMapping& dims_mapping, const PartitionedHlo& lhs,
2988 const PartitionedHlo& rhs, const Shape& output_base_shape,
2989 const HloSharding& output_sharding, const SpmdPartitionerOptions& options,
2990 int64_t num_partitions, int64_t lhs_non_contracting_partitions,
2991 int64_t rhs_non_contracting_partitions, int64_t lhs_contracting_partitions,
2992 int64_t rhs_contracting_partitions,
2993 int64_t output_lhs_non_contracting_partitions,
2994 int64_t output_rhs_non_contracting_partitions, int64_t lhs_batch_partitions,
2995 int64_t rhs_batch_partitions, int64_t output_batch_partitions,
2996 bool require_matching_devices_to_group, SpmdBuilder* b,
2997 const Window& conv_window,
2998 const std::function<StatusOr<HloInstruction*>(
2999 HloInstruction*, HloInstruction*, SpmdBuilder*,
3000 const Window& conv_window)>& create_sharded_dot,
3001 SpmdPartitioningVisitor* visitor) {
3002 const bool may_group_on_lhs_non_contracting =
3003 lhs_non_contracting_partitions == output_lhs_non_contracting_partitions &&
3004 lhs_non_contracting_partitions > 1;
3005 const bool may_group_on_rhs_non_contracting =
3006 rhs_non_contracting_partitions == output_rhs_non_contracting_partitions &&
3007 rhs_non_contracting_partitions > 1;
3008 if (!options.choose_faster_windowed_einsum_over_mem) {
3009 return false;
3010 }
3011 // Check only for a perfect dimension match for now.
3012 if (!may_group_on_lhs_non_contracting && !may_group_on_rhs_non_contracting) {
3013 return false;
3014 }
3015 std::optional<int64_t> lhs_matching_iterations;
3016 std::optional<int64_t> rhs_matching_iterations;
3017 const int64_t lhs_matching_non_contracting_partitions =
3018 may_group_on_lhs_non_contracting ? lhs_non_contracting_partitions : 0;
3019 const int64_t rhs_matching_non_contracting_partitions =
3020 may_group_on_rhs_non_contracting ? rhs_non_contracting_partitions : 0;
3021 std::tie(lhs_matching_iterations, rhs_matching_iterations) =
3022 EstimateWindowedEinsumIterationsForNonContractingPartitioning(
3023 dims_mapping, lhs, rhs, output_base_shape, output_sharding, options,
3024 num_partitions, lhs_non_contracting_partitions,
3025 rhs_non_contracting_partitions,
3026 lhs_matching_non_contracting_partitions,
3027 rhs_matching_non_contracting_partitions, lhs_contracting_partitions,
3028 rhs_contracting_partitions, output_lhs_non_contracting_partitions,
3029 output_rhs_non_contracting_partitions, lhs_batch_partitions,
3030 rhs_batch_partitions, conv_window);
3031 if (!lhs_matching_iterations && !rhs_matching_iterations) {
3032 return false;
3033 }
3034 // Be conservative and handle only the case where the contracting-dimension
3035 // partition counts of the lhs and rhs match.
3036 if (!(lhs_contracting_partitions == rhs_contracting_partitions &&
3037 lhs_contracting_partitions > 1)) {
3038 return false;
3039 }
3040 // Estimate the number of iterations in the case where we partition on the
3041 // contracting dimensions instead.
3042 std::vector<int64_t> lhs_dims;
3043 std::vector<int64_t> rhs_dims;
3044 int64_t group_count = 1;
3045 for (const auto& dim : dims_mapping.contracting_dims) {
3046 lhs_dims.push_back(dim.lhs);
3047 rhs_dims.push_back(dim.rhs);
3048 group_count *= lhs.sharding().tile_assignment().dim(dim.lhs);
3049 }
3050 HloSharding lhs_sharding = HloSharding::Replicate();
3051 HloSharding rhs_sharding = HloSharding::Replicate();
3052 std::tie(lhs_sharding, rhs_sharding) =
3053 GetDotGroupPartitionContractingLhsRhsShardings(
3054 lhs, rhs, dims_mapping.contracting_dims);
3055 auto lhs_grouped =
3056 hlo_sharding_util::GroupShardingOnDims(lhs_sharding, lhs_dims);
3057 auto rhs_grouped =
3058 hlo_sharding_util::GroupShardingOnDims(rhs_sharding, rhs_dims);
3059 rhs_grouped = AlignGroupsWith(rhs_grouped, lhs_grouped);
3060 rhs_sharding = UngroupSharding(rhs_grouped);
3061
3062 if (require_matching_devices_to_group && rhs.sharding() != rhs_sharding) {
3063 return false;
3064 }
3065 const int64_t new_num_partitions =
3066 num_partitions / lhs_contracting_partitions;
3067
3068 HloSharding inner_output_sharding = HloSharding::Replicate();
3069 HloSharding outer_output_tmp_sharding = HloSharding::Replicate();
3070 std::vector<int64_t> output_slice_dims;
3071 std::tie(inner_output_sharding, outer_output_tmp_sharding) =
3072 GetDotGroupPartitionContractingOutputShardings(
3073 dims_mapping, lhs_grouped, output_base_shape, output_sharding,
3074 group_count, output_lhs_non_contracting_partitions,
3075 output_rhs_non_contracting_partitions, output_batch_partitions,
3076 &output_slice_dims);
3077 Shape inner_output_base_shape = output_base_shape;
3078 if (!output_slice_dims.empty()) {
3079 std::vector<int64_t> non_group_dims;
3080 for (int64_t i = 0; i < output_base_shape.rank(); ++i) {
3081 if (!absl::c_linear_search(output_slice_dims, i)) {
3082 non_group_dims.push_back(i);
3083 }
3084 }
3085 inner_output_base_shape = MakePartitionedShape(
3086 output_base_shape,
3087 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
3088 output_sharding, non_group_dims));
3089 }
3090 int64_t new_output_lhs_non_contracting_partitions = 1;
3091 int64_t new_output_rhs_non_contracting_partitions = 1;
3092 if (!inner_output_sharding.IsTileMaximal()) {
3093 for (const auto& dim : dims_mapping.lhs_non_contracting_dims) {
3094 new_output_lhs_non_contracting_partitions *=
3095 inner_output_sharding.tile_assignment().dim(dim.output);
3096 }
3097 for (const auto& dim : dims_mapping.rhs_non_contracting_dims) {
3098 if (dim.output != -1) {
3099 new_output_rhs_non_contracting_partitions *=
3100 inner_output_sharding.tile_assignment().dim(dim.output);
3101 }
3102 }
3103 }
3104
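// Check whether a windowed einsum would actually be generated when grouping
// on the contracting dimensions; if not, there is nothing to prioritize.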
3105 const DotDimensionIndexMapping indices_map = ComputeDimensionIndexMapping(
3106 dims_mapping, lhs.base_shape().rank(), rhs.base_shape().rank(),
3107 inner_output_base_shape.rank());
3108 std::optional<HloSharding> output_sharding_transposed_to_match_lhs =
3109 hlo_sharding_util::TransposeShardingWithCollapsedDims(
3110 inner_output_sharding, indices_map.output_to_lhs_indices,
3111 indices_map.lhs_to_output_indices);
3112 std::optional<HloSharding> output_sharding_transposed_to_match_rhs =
3113 hlo_sharding_util::TransposeShardingWithCollapsedDims(
3114 inner_output_sharding, indices_map.output_to_rhs_indices,
3115 indices_map.rhs_to_output_indices);
3116 auto lhs_sharding_transposed_to_match_rhs =
3117 hlo_sharding_util::TransposeShardingWithCollapsedDims(
3118 lhs_sharding, indices_map.lhs_to_rhs_indices,
3119 indices_map.rhs_to_lhs_indices);
3120 auto rhs_sharding_transposed_to_match_lhs =
3121 hlo_sharding_util::TransposeShardingWithCollapsedDims(
3122 rhs_sharding, indices_map.rhs_to_lhs_indices,
3123 indices_map.lhs_to_rhs_indices);
3124 std::optional<WindowedEinsumConfig> e_config = GetWindowedEinsumConfiguration(
3125 new_num_partitions, new_output_lhs_non_contracting_partitions,
3126 new_output_rhs_non_contracting_partitions, 1,
3127 rhs_non_contracting_partitions, rhs_batch_partitions, 1,
3128 lhs_non_contracting_partitions, lhs_batch_partitions,
3129 ShapeSizeInBytes(GetPerGroupBaseShape(rhs_grouped, rhs.base_shape())),
3130 ShapeSizeInBytes(GetPerGroupBaseShape(lhs_grouped, lhs.base_shape())),
3131 ShapeSizeInBytes(inner_output_base_shape), options,
3132 output_sharding_transposed_to_match_lhs,
3133 output_sharding_transposed_to_match_rhs,
3134 lhs_sharding_transposed_to_match_rhs,
3135 rhs_sharding_transposed_to_match_lhs, lhs_grouped.sharding,
3136 rhs_grouped.sharding, conv_window, dims_mapping);
3137 if (!e_config) {
3138 return false;
3139 }
3140
3141 int64_t num_iterations = lhs_matching_iterations ? *lhs_matching_iterations
3142 : *rhs_matching_iterations;
3143 HloInstruction* other_hlo = lhs_matching_iterations ? rhs.hlo() : lhs.hlo();
3144 auto other_non_contracting_dims = lhs_matching_iterations
3145 ? dims_mapping.rhs_non_contracting_dims
3146 : dims_mapping.lhs_non_contracting_dims;
3147 auto other_sharding =
3148 lhs_matching_iterations ? rhs.sharding() : lhs.sharding();
3149 auto other_grouped = lhs_matching_iterations ? rhs_grouped : lhs_grouped;
3150 Shape other_base_shape =
3151 lhs_matching_iterations ? rhs.base_shape() : lhs.base_shape();
3152
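// Compare the communication cost of the two strategies: matching a
// non-contracting dimension all-gathers the other operand, while grouping on
// the contracting dimensions reduce-scatters the output on each iteration.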
3153 const int64_t all_gather_bytes =
3154 ShapeUtil::ByteSizeOf(other_hlo->shape()) * new_num_partitions;
3155 const int64_t reduce_scatter_bytes =
3156 ShapeUtil::ByteSizeOf(inner_output_base_shape) / new_num_partitions *
3157 num_iterations;
3158 std::vector<int64_t> ag_replication_dims;
3159 ag_replication_dims.reserve(other_non_contracting_dims.size());
3160 for (const DotConvDimsMapping::DimsMapping& dim :
3161 other_non_contracting_dims) {
3162 ag_replication_dims.push_back(lhs_matching_iterations ? dim.rhs : dim.lhs);
3163 }
3164 auto all_gather_subgroups =
3165 GetPartitionGroupsForReplication(other_sharding, ag_replication_dims);
3166 auto reduce_scatter_subgroups = GetPartitionGroupsForReplication(
3167 outer_output_tmp_sharding, output_slice_dims);
3168 const double all_gather_time_in_ms = visitor->GetCommunicationTimeInMilliSec(
3169 all_gather_bytes, visitor->CreateReplicaGroups(all_gather_subgroups));
3170 const double reduce_scatter_time_in_ms =
3171 visitor->GetCommunicationTimeInMilliSec(
3172 reduce_scatter_bytes,
3173 visitor->CreateReplicaGroups(reduce_scatter_subgroups));
3174
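// Temporarily give the other operand its per-group shape so the cost model
// sees the shard-sized dot, then restore the original shape afterwards.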
3175 Shape other_original_shape = other_hlo->shape();
3176 *other_hlo->mutable_shape() =
3177 GetPerGroupBaseShape(other_grouped, other_base_shape);
3178 HloInstruction* dot =
3179 create_sharded_dot(lhs_matching_iterations ? lhs.hlo() : other_hlo,
3180 lhs_matching_iterations ? other_hlo : rhs.hlo(), b,
3181 conv_window)
3182 .ValueOrDie();
3183 const double computation_time_in_ms =
3184 visitor->GetComputationTimeInMilliSec(dot);
3185 *other_hlo->mutable_shape() = other_original_shape;
3186
3187 VLOG(2) << "lhs: " << lhs.hlo()->ToString() << "\n"
3188 << "rhs: " << rhs.hlo()->ToString() << "\n"
3189 << "new_num_partitions: " << new_num_partitions
3190 << " num_iterations: " << num_iterations << "\n"
3191 << "all_gather_bytes: " << all_gather_bytes
3192 << " reduce_scatter_bytes: " << reduce_scatter_bytes << "\n"
3193 << "all_gather_time_in_ms: " << all_gather_time_in_ms
3194 << " reduce_scatter_time_in_ms: " << reduce_scatter_time_in_ms << "\n"
3195 << "dot: " << dot->ToString() << "\n"
3196 << "computation_time_in_ms: " << computation_time_in_ms;
3197 if (computation_time_in_ms == 0.0 || all_gather_time_in_ms == 0.0 ||
3198 reduce_scatter_time_in_ms == 0.0) {
3199 const int64_t min_nc_iterations = std::min(
3200 lhs_matching_iterations ? *lhs_matching_iterations : INT64_MAX,
3201 rhs_matching_iterations ? *rhs_matching_iterations : INT64_MAX);
3202 return min_nc_iterations > new_num_partitions;
3203 } else if ((computation_time_in_ms <= all_gather_time_in_ms) &&
3204 (computation_time_in_ms <= reduce_scatter_time_in_ms)) {
3205 return all_gather_bytes / new_num_partitions <
3206 reduce_scatter_bytes / num_iterations;
3207 } else {
3208 return all_gather_time_in_ms > reduce_scatter_time_in_ms;
3209 }
3210 }
3211
3212 // Returns true if matching the LHS operand (rather than the RHS operand) of
3213 // a dot is the better choice for non-contracting-dimension partitioning.
3214 bool LhsIsBestMatchForNonContractingPartitioning(
3215 const DotConvDimsMapping& dims_mapping, const PartitionedHlo& lhs,
3216 const PartitionedHlo& rhs, const Shape& output_base_shape,
3217 const HloSharding& output_sharding, const SpmdPartitionerOptions& options,
3218 int64_t num_partitions, int64_t lhs_non_contracting_partitions,
3219 int64_t rhs_non_contracting_partitions, int64_t lhs_matching_partitions,
3220 int64_t rhs_matching_partitions, int64_t lhs_contracting_partitions,
3221 int64_t rhs_contracting_partitions,
3222 int64_t output_lhs_non_contracting_partitions,
3223 int64_t output_rhs_non_contracting_partitions, int64_t lhs_batch_partitions,
3224 int64_t rhs_batch_partitions, SpmdBuilder* b, const Window& conv_window,
3225 const std::function<StatusOr<HloInstruction*>(
3226 HloInstruction*, HloInstruction*, SpmdBuilder*,
3227 const Window& conv_window)>& create_sharded_dot,
3228 SpmdPartitioningVisitor* visitor) {
3229 const bool may_group_on_lhs_non_contracting =
3230 lhs_non_contracting_partitions == output_lhs_non_contracting_partitions &&
3231 lhs_non_contracting_partitions > 1;
3232 const bool may_group_on_rhs_non_contracting =
3233 rhs_non_contracting_partitions == output_rhs_non_contracting_partitions &&
3234 rhs_non_contracting_partitions > 1;
3235 // If both match output non-contracting dimensions, choose the one which
3236 // will result in smaller replication of the other operand.
3237 bool lhs_matching = may_group_on_lhs_non_contracting &&
3238 (!may_group_on_rhs_non_contracting ||
3239 lhs_non_contracting_partitions *
3240 ShapeUtil::ByteSizeOf(rhs.hlo()->shape()) <
3241 rhs_non_contracting_partitions *
3242 ShapeUtil::ByteSizeOf(lhs.hlo()->shape()));
3243 // If both groupings are available and the option to choose faster windowed
3244 // einsums over saving memory is enabled, try to determine which operand will
3245 // benefit more from overlapping when matched for the windowed einsum (if a
3246 // windowed einsum is going to be generated at all).
3247 // 1) When the computation is shorter than both all-gathers, we choose to
3248 // overlap with the smaller all-gather, as it has potentially smaller extra
3249 // collective-permute overhead outside of the while loop; 2) otherwise, we
3250 // choose the all-gather with the longer runtime to overlap with.
3251 if (may_group_on_lhs_non_contracting && may_group_on_rhs_non_contracting &&
3252 options.choose_faster_windowed_einsum_over_mem) {
3253 const DotDimensionIndexMapping indices_map = ComputeDimensionIndexMapping(
3254 dims_mapping, lhs.base_shape().rank(), rhs.base_shape().rank(),
3255 output_base_shape.rank());
3256 std::optional<int64_t> lhs_matching_iterations;
3257 std::optional<int64_t> rhs_matching_iterations;
3258 std::tie(lhs_matching_iterations, rhs_matching_iterations) =
3259 EstimateWindowedEinsumIterationsForNonContractingPartitioning(
3260 dims_mapping, lhs, rhs, output_base_shape, output_sharding, options,
3261 num_partitions, lhs_non_contracting_partitions,
3262 rhs_non_contracting_partitions, lhs_matching_partitions,
3263 rhs_matching_partitions, lhs_contracting_partitions,
3264 rhs_contracting_partitions, output_lhs_non_contracting_partitions,
3265 output_rhs_non_contracting_partitions, lhs_batch_partitions,
3266 rhs_batch_partitions, conv_window);
3267 if (lhs_matching_iterations && rhs_matching_iterations) {
3268 const int64_t lhs_all_gather_bytes =
3269 ShapeUtil::ByteSizeOf(lhs.hlo()->shape()) *
3270 rhs_non_contracting_partitions;
3271 const int64_t rhs_all_gather_bytes =
3272 ShapeUtil::ByteSizeOf(rhs.hlo()->shape()) *
3273 lhs_non_contracting_partitions;
3274 auto lhs_grouped =
3275 GetNonContractingPartitionGroupedShardingForMatchedOperand(
3276 /*lhs_matching=*/true, lhs.sharding(), output_sharding,
3277 dims_mapping.lhs_non_contracting_dims);
3278 auto lhs_all_gather_subgroups = lhs_grouped.device_groups;
3279 auto rhs_grouped =
3280 GetNonContractingPartitionGroupedShardingForMatchedOperand(
3281 /*lhs_matching=*/false, rhs.sharding(), output_sharding,
3282 dims_mapping.rhs_non_contracting_dims);
3283 auto rhs_all_gather_subgroups = rhs_grouped.device_groups;
3284 const double lhs_all_gather_time_in_ms =
3285 visitor->GetCommunicationTimeInMilliSec(
3286 lhs_all_gather_bytes,
3287 visitor->CreateReplicaGroups(lhs_all_gather_subgroups));
3288 const double rhs_all_gather_time_in_ms =
3289 visitor->GetCommunicationTimeInMilliSec(
3290 rhs_all_gather_bytes,
3291 visitor->CreateReplicaGroups(rhs_all_gather_subgroups));
3292
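// Estimate the per-shard dot computation time by temporarily setting both
// operands to their per-group base shapes and restoring them afterwards.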
3293 HloInstruction* compute_lhs = lhs.hlo();
3294 Shape lhs_original_shape = compute_lhs->shape();
3295 *compute_lhs->mutable_shape() =
3296 GetPerGroupBaseShape(lhs_grouped, lhs.base_shape());
3297 HloInstruction* compute_rhs = rhs.hlo();
3298 Shape rhs_original_shape = compute_rhs->shape();
3299 *compute_rhs->mutable_shape() =
3300 GetPerGroupBaseShape(rhs_grouped, rhs.base_shape());
3301 HloInstruction* dot =
3302 create_sharded_dot(compute_lhs, compute_rhs, b, conv_window)
3303 .ValueOrDie();
3304 const double computation_time_in_ms =
3305 visitor->GetComputationTimeInMilliSec(dot);
3306 *compute_lhs->mutable_shape() = lhs_original_shape;
3307 *compute_rhs->mutable_shape() = rhs_original_shape;
3308
3309 VLOG(2) << "lhs: " << lhs.hlo()->ToString() << "\n"
3310 << "rhs: " << rhs.hlo()->ToString() << "\n"
3311 << "lhs_non_contracting_partitions: "
3312 << lhs_non_contracting_partitions
3313 << " rhs_non_contracting_partitions: "
3314 << rhs_non_contracting_partitions << "\n"
3315 << "lhs_matching_iterations: " << *lhs_matching_iterations
3316 << " rhs_matching_iterations: " << *rhs_matching_iterations
3317 << "\n"
3318 << "lhs_all_gather_bytes: " << lhs_all_gather_bytes
3319 << " rhs_all_gather_bytes: " << rhs_all_gather_bytes << "\n"
3320 << "lhs_all_gather_time_in_ms: " << lhs_all_gather_time_in_ms
3321 << " rhs_all_gather_time_in_ms: " << rhs_all_gather_time_in_ms
3322 << "\n"
3323 << "dot: " << dot->ToString() << "\n"
3324 << "computation_time_in_ms: " << computation_time_in_ms;
3325 if (computation_time_in_ms == 0.0 || lhs_all_gather_time_in_ms == 0.0 ||
3326 rhs_all_gather_time_in_ms == 0.0) {
3327 lhs_matching = *lhs_matching_iterations < *rhs_matching_iterations;
3328 } else if ((computation_time_in_ms <= lhs_all_gather_time_in_ms) &&
3329 (computation_time_in_ms <= rhs_all_gather_time_in_ms)) {
3330 lhs_matching = lhs_all_gather_bytes / rhs_non_contracting_partitions >
3331 rhs_all_gather_bytes / lhs_non_contracting_partitions;
3332 } else {
3333 lhs_matching = lhs_all_gather_time_in_ms > rhs_all_gather_time_in_ms;
3334 }
3335 }
3336 }
3337 return lhs_matching;
3338 }
3339
3340 // Recursive partitioning function. If there are partial dimensions matching
3341 // in the operands and output, group the devices and recursively partition
3342 // the in-group dot.
3343 StatusOr<HloInstruction*> PartitionDot(
3344 PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
3345 const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
3346 int64_t num_partitions,
3347 const std::function<StatusOr<HloInstruction*>(
3348 HloInstruction*, HloInstruction*, SpmdBuilder*,
3349 const Window& conv_window)>& create_sharded_dot,
3350 const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
3351 bool require_matching_devices_to_group,
3352 const SpmdPartitionerOptions& options, SpmdBuilder* b,
3353 std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
3354 windowed_dot_general_loops,
3355 SpmdPartitioningVisitor* visitor) {
3356 // If lhs' hlo and rhs' hlo are identical, make a copy for rhs.
3357 if (lhs.hlo() == rhs.hlo()) {
3358 auto copy_hlo = b->AddInstruction(HloInstruction::CreateUnary(
3359 rhs.hlo()->shape(), HloOpcode::kCopy, rhs.hlo()));
3360 copy_hlo->set_sharding(rhs.sharding());
3361 rhs = PartitionedHlo(copy_hlo, rhs.base_shape(), rhs.state());
3362 }
3363
3364 // lhs_rhs_or_output: 0 lhs, 1 rhs, 2 output.
3365 auto get_partitions_for_dims =
3366 [&](const HloSharding& sharding,
3367 absl::Span<const DotConvDimsMapping::DimsMapping> dims,
3368 int lhs_rhs_or_output) {
3369 int64_t partitions = 1;
3370 if (sharding.IsTileMaximal()) {
3371 return partitions;
3372 }
3373 for (const auto& dim : dims) {
3374 if (lhs_rhs_or_output == 0) {
3375 partitions *= sharding.tile_assignment().dim(dim.lhs);
3376 } else if (lhs_rhs_or_output == 1) {
3377 partitions *= sharding.tile_assignment().dim(dim.rhs);
3378 } else {
3379 CHECK_EQ(lhs_rhs_or_output, 2);
3380 partitions *= sharding.tile_assignment().dim(dim.output);
3381 }
3382 }
3383 return partitions;
3384 };
3385 const int64_t lhs_batch_partitions =
3386 get_partitions_for_dims(lhs.sharding(), dims_mapping.batch_dims, 0);
3387 const int64_t rhs_batch_partitions =
3388 get_partitions_for_dims(rhs.sharding(), dims_mapping.batch_dims, 1);
3389 const int64_t output_batch_partitions =
3390 get_partitions_for_dims(output_sharding, dims_mapping.batch_dims, 2);
3391 const int64_t lhs_contracting_partitions =
3392 get_partitions_for_dims(lhs.sharding(), dims_mapping.contracting_dims, 0);
3393 const int64_t rhs_contracting_partitions =
3394 get_partitions_for_dims(rhs.sharding(), dims_mapping.contracting_dims, 1);
3395 const int64_t lhs_non_contracting_partitions = get_partitions_for_dims(
3396 lhs.sharding(), dims_mapping.lhs_non_contracting_dims, 0);
3397 const int64_t rhs_non_contracting_partitions = get_partitions_for_dims(
3398 rhs.sharding(), dims_mapping.rhs_non_contracting_dims, 1);
3399 const int64_t output_lhs_non_contracting_partitions = get_partitions_for_dims(
3400 output_sharding, dims_mapping.lhs_non_contracting_dims, 2);
3401 const int64_t output_rhs_non_contracting_partitions = get_partitions_for_dims(
3402 output_sharding, dims_mapping.rhs_non_contracting_dims, 2);
3403 const int64_t lhs_conv_spatial_partitions = get_partitions_for_dims(
3404 lhs.sharding(), dims_mapping.conv_spatial_dims, 0);
3405 const int64_t rhs_conv_spatial_partitions = get_partitions_for_dims(
3406 rhs.sharding(), dims_mapping.conv_spatial_dims, 1);
3407 const int64_t output_conv_spatial_partitions = get_partitions_for_dims(
3408 output_sharding, dims_mapping.conv_spatial_dims, 2);
3409 // Before we find partial matches along the dimensions, invoke base case
3410 // again without may_reshard_without_detecting_match.
3411
3412 // Try to partition the purely spatially-partitioned convolution with the
3413 // convolution spatial dimensions partitioned or the depthwise parallel
3414 // dimension partitioned.
3415 bool is_conv_spatial_dim_partitioned =
3416 (lhs_conv_spatial_partitions > 1 || rhs_conv_spatial_partitions > 1 ||
3417 output_conv_spatial_partitions > 1);
3418 bool is_conv_batch_or_contracting_dim_partitioned =
3419 (lhs_batch_partitions > 1 || rhs_batch_partitions > 1 ||
3420 output_batch_partitions > 1 ||
3421 (lhs_contracting_partitions > 1 && rhs_contracting_partitions > 1));
3422 if ((!dims_mapping.conv_spatial_dims.empty() &&
3423 is_conv_spatial_dim_partitioned &&
3424 !is_conv_batch_or_contracting_dim_partitioned) ||
3425 (original_hlo->opcode() == HloOpcode::kConvolution &&
3426 (original_hlo->batch_group_count() > 1 ||
3427 original_hlo->feature_group_count() > 1))) {
3428 // Partitioning with kernel_input_feature_dim > 1 and feature_group_count >
3429 // 1 is not supported.
3430 const auto& dnums = original_hlo->convolution_dimension_numbers();
3431 if (original_hlo->feature_group_count() > 1 &&
3432 rhs.hlo()->shape().dimensions(dnums.kernel_input_feature_dimension()) >
3433 1) {
3434 return nullptr;
3435 }
3436
3437 TF_ASSIGN_OR_RETURN(
3438 auto partitioned_conv,
3439 PartitionConvolution(lhs, rhs, output_base_shape, output_sharding,
3440 dims_mapping, create_sharded_dot, conv_window,
3441 original_hlo, num_partitions, options,
3442 lhs.state().partition_id, module, b));
3443
3444 if (partitioned_conv) {
3445 return partitioned_conv;
3446 }
3447
3448 // Recursively partition on different types of dimensions for
3449 // convolution. Case 0.a: Group partitions by feature group count.
3450 if (original_hlo->feature_group_count() > 1 ||
3451 original_hlo->batch_group_count() > 1) {
3452 std::optional<DotConvDimsMapping> new_dims_mapping;
3453 if (original_hlo->feature_group_count() > 1) {
3454 const int64_t input_feature_dim =
3455 original_hlo->convolution_dimension_numbers()
3456 .input_feature_dimension();
3457 const int64_t kernel_output_feature_dim =
3458 original_hlo->convolution_dimension_numbers()
3459 .kernel_output_feature_dimension();
3460 // If the input and output feature dims are not equal, we require the
3461 // feature_group_count to be evenly partitioned; otherwise, there will
3462 // be different padding in the input/output.
3463 // TODO(xla): Use halo exchange to solve this problem. Can be a
3464 // preprocessing that uses padding/slicing to make the shape evenly
3465 // shardable.
3466 if (lhs.base_shape().dimensions(input_feature_dim) ==
3467 rhs.base_shape().dimensions(kernel_output_feature_dim) ||
3468 (lhs.sharding().IsTiled() &&
3469 original_hlo->feature_group_count() %
3470 ShardCountAtDim(lhs.sharding(), input_feature_dim) ==
3471 0)) {
3472 new_dims_mapping = ConvertDimsMappingWithFeatureGroupCount(
3473 dims_mapping, original_hlo);
3474 }
3475 }
3476
3477 if (original_hlo->batch_group_count() > 1) {
3478 const int64_t input_batch_dim =
3479 original_hlo->convolution_dimension_numbers()
3480 .input_batch_dimension();
3481 const int64_t kernel_output_feature_dim =
3482 original_hlo->convolution_dimension_numbers()
3483 .kernel_output_feature_dimension();
3484 if (lhs.base_shape().dimensions(input_batch_dim) ==
3485 rhs.base_shape().dimensions(kernel_output_feature_dim) ||
3486 (lhs.sharding().IsTiled() &&
3487 original_hlo->batch_group_count() %
3488 ShardCountAtDim(lhs.sharding(), input_batch_dim) ==
3489 0)) {
3490 new_dims_mapping =
3491 ConvertDimsMappingWithBatchGroupCount(dims_mapping, original_hlo);
3492 }
3493 }
3494 if (!new_dims_mapping.has_value()) {
3495 return nullptr;
3496 }
3497
3498 const int64_t conv_lhs_contracting_partitions = get_partitions_for_dims(
3499 lhs.sharding(), new_dims_mapping->contracting_dims, 0);
3500 const int64_t conv_rhs_contracting_partitions = get_partitions_for_dims(
3501 rhs.sharding(), new_dims_mapping->contracting_dims, 1);
3502 const int64_t conv_lhs_non_contracting_partitions =
3503 get_partitions_for_dims(
3504 lhs.sharding(), new_dims_mapping->lhs_non_contracting_dims, 0);
3505 const int64_t conv_rhs_non_contracting_partitions =
3506 get_partitions_for_dims(
3507 rhs.sharding(), new_dims_mapping->rhs_non_contracting_dims, 1);
3508 const int64_t conv_lhs_batch_partitions = get_partitions_for_dims(
3509 lhs.sharding(), new_dims_mapping->batch_dims, 0);
3510 const int64_t conv_rhs_batch_partitions = get_partitions_for_dims(
3511 rhs.sharding(), new_dims_mapping->batch_dims, 1);
3512 const int64_t conv_output_batch_partitions = get_partitions_for_dims(
3513 output_sharding, new_dims_mapping->batch_dims, 2);
3514 if ((conv_lhs_batch_partitions == conv_output_batch_partitions ||
3515 conv_rhs_batch_partitions == conv_output_batch_partitions) &&
3516 conv_output_batch_partitions > 1) {
3517 TF_ASSIGN_OR_RETURN(
3518 auto try_partitioned_conv,
3519 PartitionDotGroupOnBatch(
3520 lhs, rhs, output_base_shape, output_sharding, *new_dims_mapping,
3521 num_partitions, conv_lhs_contracting_partitions,
3522 conv_rhs_contracting_partitions,
3523 conv_lhs_non_contracting_partitions,
3524 conv_rhs_non_contracting_partitions, create_sharded_dot,
3525 conv_window, module, original_hlo,
3526 require_matching_devices_to_group, options, b,
3527 windowed_dot_general_loops, visitor));
3528 if (try_partitioned_conv) {
3529 return try_partitioned_conv;
3530 }
3531 }
3532 return nullptr;
3533 }
3534 }
3535
3536 TF_ASSIGN_OR_RETURN(
3537 auto try_partitioned_dot,
3538 PartitionBaseCase(
3539 lhs, rhs, output_base_shape, output_sharding, dims_mapping,
3540 num_partitions, create_sharded_dot, conv_window, module, original_hlo,
3541 lhs_batch_partitions, rhs_batch_partitions, output_batch_partitions,
3542 lhs_contracting_partitions, rhs_contracting_partitions,
3543 lhs_non_contracting_partitions, rhs_non_contracting_partitions,
3544 output_lhs_non_contracting_partitions,
3545 output_rhs_non_contracting_partitions, options, b,
3546 windowed_dot_general_loops,
3547 /*may_reshard_without_detecting_match=*/false, visitor));
3548 if (try_partitioned_dot) {
3549 return try_partitioned_dot;
3550 }
3551
3552 // Recursively partition on different types of dimensions.
3553 //
3554 // Case 1: Group partitions by batch.
3555 if ((lhs_batch_partitions == output_batch_partitions ||
3556 rhs_batch_partitions == output_batch_partitions) &&
3557 output_batch_partitions > 1) {
3558 TF_ASSIGN_OR_RETURN(
3559 auto dot,
3560 PartitionDotGroupOnBatch(
3561 lhs, rhs, output_base_shape, output_sharding, dims_mapping,
3562 num_partitions, lhs_contracting_partitions,
3563 rhs_contracting_partitions, lhs_non_contracting_partitions,
3564 rhs_non_contracting_partitions, create_sharded_dot, conv_window,
3565 module, original_hlo, require_matching_devices_to_group, options, b,
3566 windowed_dot_general_loops, visitor));
3567 if (dot) {
3568 return dot;
3569 }
3570 }
3571
3572 // Case 2: Group partitions by non-contracting dimensions.
3573 const bool may_group_on_lhs_non_contracting =
3574 lhs_non_contracting_partitions == output_lhs_non_contracting_partitions &&
3575 lhs_non_contracting_partitions > 1;
3576 const bool may_group_on_rhs_non_contracting =
3577 rhs_non_contracting_partitions == output_rhs_non_contracting_partitions &&
3578 rhs_non_contracting_partitions > 1;
3579 bool lhs_matching = false;
3580 std::vector<DotConvDimsMapping::DimsMapping> matching_dims;
3581 if (may_group_on_lhs_non_contracting || may_group_on_rhs_non_contracting) {
3582 lhs_matching = LhsIsBestMatchForNonContractingPartitioning(
3583 dims_mapping, lhs, rhs, output_base_shape, output_sharding, options,
3584 num_partitions, lhs_non_contracting_partitions,
3585 rhs_non_contracting_partitions, lhs_non_contracting_partitions,
3586 rhs_non_contracting_partitions, lhs_contracting_partitions,
3587 rhs_contracting_partitions, output_lhs_non_contracting_partitions,
3588 output_rhs_non_contracting_partitions, lhs_batch_partitions,
3589 rhs_batch_partitions, b, conv_window, create_sharded_dot, visitor);
3590 matching_dims = lhs_matching ? dims_mapping.lhs_non_contracting_dims
3591 : dims_mapping.rhs_non_contracting_dims;
3592 } else if (lhs_non_contracting_partitions > 1 &&
3593 output_lhs_non_contracting_partitions > 1) {
3594 lhs_matching = true;
3595 for (const auto& dim : dims_mapping.lhs_non_contracting_dims) {
3596 int64_t lhs_partitions = lhs.sharding().tile_assignment().dim(dim.lhs);
3597 if (lhs_partitions > 1 &&
3598 lhs_partitions == output_sharding.tile_assignment().dim(dim.output)) {
3599 matching_dims.push_back(dim);
3600 }
3601 }
3602 } else if (rhs_non_contracting_partitions > 1 &&
3603 output_rhs_non_contracting_partitions > 1) {
3604 lhs_matching = false;
3605 for (const auto& dim : dims_mapping.rhs_non_contracting_dims) {
3606 int64_t rhs_partitions = rhs.sharding().tile_assignment().dim(dim.rhs);
3607 if (rhs_partitions > 1 &&
3608 rhs_partitions == output_sharding.tile_assignment().dim(dim.output)) {
3609 matching_dims.push_back(dim);
3610 }
3611 }
3612 }
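// Decide whether grouping on the contracting dimensions first (Case 3 below)
// is estimated to be faster than grouping on the matched non-contracting
// dimensions; if so, skip the non-contracting grouping.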
3613 const bool prioritize_contracting_for_faster_windowed_einsum =
3614 PrioritizeContractingDimensionsPartitioning(
3615 dims_mapping, lhs, rhs, output_base_shape, output_sharding, options,
3616 num_partitions, lhs_non_contracting_partitions,
3617 rhs_non_contracting_partitions, lhs_contracting_partitions,
3618 rhs_contracting_partitions, output_lhs_non_contracting_partitions,
3619 output_rhs_non_contracting_partitions, lhs_batch_partitions,
3620 rhs_batch_partitions, output_batch_partitions,
3621 require_matching_devices_to_group, b, conv_window, create_sharded_dot,
3622 visitor);
3623 if (!(matching_dims.empty() ||
3624 prioritize_contracting_for_faster_windowed_einsum)) {
3625 TF_ASSIGN_OR_RETURN(
3626 auto dot,
3627 PartitionDotGroupOnNonContracting(
3628 lhs_matching, lhs_matching ? lhs : rhs, lhs_matching ? rhs : lhs,
3629 lhs_matching ? lhs_contracting_partitions
3630 : rhs_contracting_partitions,
3631 lhs_matching ? rhs_contracting_partitions
3632 : lhs_contracting_partitions,
3633 matching_dims,
3634 lhs_matching ? rhs_non_contracting_partitions
3635 : lhs_non_contracting_partitions,
3636 lhs_matching ? output_rhs_non_contracting_partitions
3637 : output_lhs_non_contracting_partitions,
3638 output_base_shape, output_sharding, dims_mapping, num_partitions,
3639 create_sharded_dot, conv_window, module, original_hlo,
3640 require_matching_devices_to_group, options, b,
3641 windowed_dot_general_loops, visitor));
3642 if (dot) {
3643 return dot;
3644 }
3645 }
3646
3647 // Case 3: Group partitions by contracting dimensions.
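// Rough sketch: if the contracting dimension K is tiled identically on both
// operands, each group computes a partial dot over its K-shard and the
// partial results are then combined (e.g., all-reduce or reduce-scatter,
// depending on the output sharding).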
3648 if (lhs_contracting_partitions == rhs_contracting_partitions &&
3649 lhs_contracting_partitions > 1) {
3650 TF_ASSIGN_OR_RETURN(
3651 auto dot,
3652 PartitionDotGroupOnContracting(
3653 lhs, rhs, dims_mapping.contracting_dims, output_batch_partitions,
3654 output_lhs_non_contracting_partitions,
3655 output_rhs_non_contracting_partitions, output_base_shape,
3656 output_sharding, dims_mapping, num_partitions, create_sharded_dot,
3657 conv_window, module, original_hlo,
3658 require_matching_devices_to_group, options, b,
3659 windowed_dot_general_loops, visitor));
3660 if (dot) {
3661 return dot;
3662 }
3663 }
3664 if (lhs_contracting_partitions > 1 && rhs_contracting_partitions > 1) {
3665 // If a subset of the contracting dims match, try those.
3666 std::vector<DotConvDimsMapping::DimsMapping> matching_dims;
3667 for (const auto& dim : dims_mapping.contracting_dims) {
3668 int64_t lhs_partitions = lhs.sharding().tile_assignment().dim(dim.lhs);
3669 if (lhs_partitions > 1 &&
3670 lhs_partitions == rhs.sharding().tile_assignment().dim(dim.rhs)) {
3671 matching_dims.push_back(dim);
3672 }
3673 }
3674 if (!matching_dims.empty()) {
3675 TF_ASSIGN_OR_RETURN(
3676 auto dot, PartitionDotGroupOnContracting(
3677 lhs, rhs, matching_dims, output_batch_partitions,
3678 output_lhs_non_contracting_partitions,
3679 output_rhs_non_contracting_partitions,
3680 output_base_shape, output_sharding, dims_mapping,
3681 num_partitions, create_sharded_dot, conv_window, module,
3682 original_hlo, require_matching_devices_to_group,
3683 options, b, windowed_dot_general_loops, visitor));
3684 if (dot) {
3685 return dot;
3686 }
3687 }
3688 }
3689
3690 // Case 4: If the operands are replicated but the output is partially
3691 // replicated, recurse with the partial replication removed.
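// Roughly: group the devices by the output's replication (last tile)
// dimension and recurse within each group using the output sharding with the
// replication dimension stripped.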
3692 if (lhs.sharding().IsReplicated() && rhs.sharding().IsReplicated() &&
3693 output_sharding.ReplicateOnLastTileDim()) {
3694 auto grouped_output = hlo_sharding_util::GroupShardingOnDims(
3695 output_sharding, {output_base_shape.rank()});
3696 auto inner_state = CreatePerGroupPartitioningState(
3697 lhs.state(), grouped_output.device_groups, b);
3698 TF_ASSIGN_OR_RETURN(
3699 auto dot,
3700 PartitionDot(PartitionedHlo(lhs.hlo(), lhs.base_shape(), inner_state),
3701 PartitionedHlo(rhs.hlo(), rhs.base_shape(), inner_state),
3702 output_base_shape, grouped_output.sharding, dims_mapping,
3703 output_sharding.NumTiles(), create_sharded_dot,
3704 conv_window, module, original_hlo, options, b,
3705 windowed_dot_general_loops, visitor));
3706 if (dot) {
3707 return dot;
3708 }
3709 }
3710
3711 // We failed to find partial matches; invoke the base case again with
3712 // may_reshard_without_detecting_match set.
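// (Roughly, this lets the base case reshard an operand toward the other
// side's sharding even when no already-matching partitioned dimensions were
// detected above.)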
3713 TF_ASSIGN_OR_RETURN(
3714 auto dot,
3715 PartitionBaseCase(
3716 lhs, rhs, output_base_shape, output_sharding, dims_mapping,
3717 num_partitions, create_sharded_dot, conv_window, module, original_hlo,
3718 lhs_batch_partitions, rhs_batch_partitions, output_batch_partitions,
3719 lhs_contracting_partitions, rhs_contracting_partitions,
3720 lhs_non_contracting_partitions, rhs_non_contracting_partitions,
3721 output_lhs_non_contracting_partitions,
3722 output_rhs_non_contracting_partitions, options, b,
3723 windowed_dot_general_loops,
3724 /*may_reshard_without_detecting_match=*/true, visitor));
3725 if (dot) {
3726 return dot;
3727 }
3728 return nullptr;
3729 }
3730
3731 StatusOr<HloInstruction*> PartitionDot(
3732 PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
3733 const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
3734 int64_t num_partitions,
3735 const std::function<StatusOr<HloInstruction*>(
3736 HloInstruction*, HloInstruction*, SpmdBuilder*,
3737 const Window& conv_window)>& create_sharded_dot,
3738 const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
3739 const SpmdPartitionerOptions& options, SpmdBuilder* b,
3740 std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
3741 windowed_dot_general_loops,
3742 SpmdPartitioningVisitor* visitor) {
3743 // First try partitioning without resharding the groups, then try allowing
3744 // the groups to be resharded.
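// Rough flow: the first pass (require_matching_devices_to_group=true) only
// accepts groupings whose device assignments already line up across
// lhs/rhs/output; the second pass relaxes this and may reshard onto matching
// device groups first.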
3745 for (bool require_matching_devices_to_group : {true, false}) {
3746 TF_ASSIGN_OR_RETURN(
3747 auto try_partition,
3748 PartitionDot(lhs, rhs, output_base_shape, output_sharding, dims_mapping,
3749 num_partitions, create_sharded_dot, conv_window, module,
3750 original_hlo, require_matching_devices_to_group, options,
3751 b, windowed_dot_general_loops, visitor));
3752 if (try_partition) {
3753 return try_partition;
3754 }
3755 }
3756
3757 // Default action.
3758 TF_ASSIGN_OR_RETURN(
3759 auto dot, create_sharded_dot(lhs.Replicate().hlo(), rhs.Replicate().hlo(),
3760 b, conv_window));
3761 dot->set_sharding(HloSharding::Replicate());
3762 return PartitionedHlo(dot, output_base_shape, lhs.state())
3763 .Reshard(output_sharding)
3764 .hlo();
3765 }
3766
3767 } // namespace
3768
3769 Status SpmdPartitioningVisitor::HandleDotHelper(
3770 HloInstruction* hlo, const DotConvDimsMapping& dims_mapping,
3771 const std::function<StatusOr<HloInstruction*>(
3772 HloInstruction*, HloInstruction*, SpmdBuilder*,
3773 const Window& conv_window)>& create_sharded_dot) {
3774 if (hlo->sharding().HasUniqueDevice()) {
3775 return DefaultAction(hlo);
3776 }
3777 auto& lhs = GetPartitionedHlo(hlo->operand(0));
3778 auto& rhs = GetPartitionedHlo(hlo->operand(1));
3779 Window conv_window;
3780 if (hlo->opcode() == HloOpcode::kConvolution) {
3781 conv_window = hlo->window();
3782 }
3783
3784 TF_ASSIGN_OR_RETURN(
3785 auto partitioned_dot,
3786 PartitionDot(lhs, rhs, hlo->shape(), hlo->sharding(), dims_mapping,
3787 num_partitions_, create_sharded_dot, conv_window, module_,
3788 hlo, options_, &b_, &windowed_dot_general_loops_, this));
3789 SetPartitionedHlo(hlo, [&] { return partitioned_dot; });
3790 return OkStatus();
3791 }
3792
3793 namespace {
3794
3795 // Finds a cluster of nodes that produce the inputs for `hlo` and that depend
3796 // only on small operands, which means the cluster should start with
3797 // broadcasts, constants and iotas. All other internal nodes must be
3798 // non-side-effecting elementwise ops. Returns the set of nodes and the small
3799 // operands. E.g., for the following graph,
3800 //
3801 // a -> broadcast -> multiply
3802 // iota ---> add--/
3803 // constant/
3804 //
3805 // FindInputNodesIfOnlyDependOnSmallOperands(multiply) will return
3806 // <{broadcast, iota, constant, add, multiply}, [a]>.
3807 std::pair<absl::flat_hash_set<HloInstruction*>, std::vector<HloInstruction*>>
3808 FindInputNodesIfOnlyDependOnSmallOperands(HloInstruction* hlo) {
3809 absl::flat_hash_set<HloInstruction*> nodes_found;
3810 std::vector<HloInstruction*> new_operands;
3811 absl::flat_hash_set<const HloInstruction*> new_operands_set;
3812 std::vector<HloInstruction*> worklist;
3813 worklist.push_back(hlo);
3814 while (!worklist.empty()) {
3815 auto inst = worklist.back();
3816 worklist.pop_back();
3817 if (nodes_found.count(inst) > 0) {
3818 continue;
3819 }
3820 if (inst->opcode() == HloOpcode::kBroadcast ||
3821 inst->opcode() == HloOpcode::kConstant ||
3822 inst->opcode() == HloOpcode::kIota) {
3823 nodes_found.insert(inst);
3824 for (auto o : inst->operands()) {
3825 auto res = new_operands_set.emplace(o);
3826 if (res.second) {
3827 new_operands.push_back(o);
3828 }
3829 }
3830 } else if (inst->IsElementwise() && !inst->HasSideEffectNoRecurse() &&
3831 absl::c_all_of(inst->operands(),
3832 [inst](const HloInstruction* o) {
3833 return ShapeUtil::CompatibleIgnoringElementType(
3834 o->shape(), inst->shape());
3835 })) {
3836 nodes_found.insert(inst);
3837 for (auto o : inst->operands()) {
3838 worklist.push_back(o);
3839 }
3840 } else {
3841 nodes_found.clear();
3842 new_operands.clear();
3843 break;
3844 }
3845 }
3846 return {std::move(nodes_found), std::move(new_operands)};
3847 }
3848
3849 // Moves a cluster of memory-reducing nodes into the windowed dot-general loop
3850 // on contracting dimensions. Such a loop has a dynamic slice on the
3851 // non-windowed operand. If we move the input nodes into the loop, the
3852 // dynamic-slice could be merged with them by later optimization passes, which
3853 // reduces memory.
3854 //
3855 // small_operands small_operands
3856 // | |
3857 // input_nodes loop { |
3858 // | => input_nodes
3859 // loop { | |
3860 // dynamic-slice dynamic-slice
3861 // ... ...
3862 // } }
3863 //
3864 // Later optimization passes (TpuPadSliceMover) will merge the dynamic slice
3865 // with the input nodes.
3866 Status SinkInputNodesIntoWindowedDotGeneralLoopOnContractingDimensions(
3867 HloInstruction* loop, int64_t non_windowed_operand_index) {
3868 auto input_tuple = loop->mutable_operand(0);
3869 auto old_operand = input_tuple->mutable_operand(non_windowed_operand_index);
3870 auto input_nodes = FindInputNodesIfOnlyDependOnSmallOperands(old_operand);
3871 auto to_sink = std::move(input_nodes.first);
3872 auto new_operands = std::move(input_nodes.second);
3873 if (to_sink.empty()) {
3874 return OkStatus();
3875 }
3876 auto computation = loop->parent();
3877 // Replace the old operand with a tuple of the found small operands.
3878 auto new_input_subtuple =
3879 computation->AddInstruction(HloInstruction::CreateTuple(new_operands));
3880 TF_RETURN_IF_ERROR(input_tuple->ReplaceOperandWithDifferentShape(
3881 non_windowed_operand_index, new_input_subtuple));
3882
3883 auto body = loop->while_body();
3884 auto body_param = body->parameter_instruction(0);
3885 auto old_body_param_users = body_param->users();
3886 // Update all tuple shapes.
3887 for (auto tuple : std::vector<HloInstruction*>{
3888 input_tuple, loop, loop->while_condition()->parameter_instruction(0),
3889 body_param, body->root_instruction()}) {
3890 *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(),
3891 {non_windowed_operand_index}) =
3892 new_input_subtuple->shape();
3893 }
3894 // Now update the loop body.
3895 auto new_operand_tuple_inside =
3896 body->AddInstruction(HloInstruction::CreateGetTupleElement(
3897 new_input_subtuple->shape(), body_param, non_windowed_operand_index));
3898 TF_RETURN_IF_ERROR(body->root_instruction()->ReplaceOperandWithDifferentShape(
3899 non_windowed_operand_index, new_operand_tuple_inside));
3900
3901 // Create nodes inside the loop body.
3902 std::vector<HloInstruction*> worklist;
3903 absl::flat_hash_map<const HloInstruction*, HloInstruction*> outside_to_inside;
3904 auto add_users_if_available = [&](HloInstruction* inst) {
3905 for (auto u : inst->users()) {
3906 if (outside_to_inside.count(u) == 0 && to_sink.count(u) > 0 &&
3907 absl::c_all_of(u->operands(), [&](const HloInstruction* o) {
3908 return outside_to_inside.count(o) > 0;
3909 })) {
3910 worklist.push_back(u);
3911 }
3912 }
3913 };
3914 for (int64_t i = 0; i < new_operands.size(); ++i) {
3915 outside_to_inside[new_operands[i]] =
3916 body->AddInstruction(HloInstruction::CreateGetTupleElement(
3917 new_operands[i]->shape(), new_operand_tuple_inside, i));
3918 add_users_if_available(new_operands[i]);
3919 }
3920 // HLOs to sink that have no operands.
3921 std::vector<HloInstruction*> nullaries_to_sink;
3922 for (auto inst : to_sink) {
3923 if (inst->operand_count() == 0) {
3924 nullaries_to_sink.push_back(inst);
3925 }
3926 }
3927 // Sort nullaries_to_sink to make it deterministic.
3928 absl::c_sort(nullaries_to_sink,
3929 [](const HloInstruction* a, const HloInstruction* b) {
3930 return a->unique_id() < b->unique_id();
3931 });
3932 worklist.reserve(nullaries_to_sink.size());
3933 for (auto inst : nullaries_to_sink) {
3934 worklist.push_back(inst);
3935 }
3936 while (!worklist.empty()) {
3937 auto inst = worklist.back();
3938 worklist.pop_back();
3939 std::vector<HloInstruction*> inst_new_operands(inst->operand_count());
3940 for (int64_t i = 0; i < inst->operand_count(); ++i) {
3941 inst_new_operands[i] = outside_to_inside[inst->operand(i)];
3942 }
3943 outside_to_inside[inst] = body->AddInstruction(
3944 inst->CloneWithNewOperands(inst->shape(), inst_new_operands));
3945 add_users_if_available(inst);
3946 }
3947 TF_RET_CHECK(outside_to_inside.count(old_operand) > 0);
3948 for (auto ou : old_body_param_users) {
3949 if (ou->opcode() == HloOpcode::kGetTupleElement &&
3950 ou->tuple_index() == non_windowed_operand_index) {
3951 TF_RETURN_IF_ERROR(
3952 ou->ReplaceAllUsesWith(outside_to_inside[old_operand]));
3953 TF_RETURN_IF_ERROR(body->RemoveInstruction(ou));
3954 }
3955 }
3956 return OkStatus();
3957 }
3958
3959 // Checks that a condition holds for all transitive operands of an HLO.
3960 bool CheckOperandsRecursive(const HloInstruction* hlo,
3961 std::function<bool(const HloInstruction*)> check) {
3962 std::deque<const HloInstruction*> worklist;
3963 worklist.push_front(hlo);
3964 while (!worklist.empty()) {
3965 auto inst = worklist.back();
3966 worklist.pop_back();
3967 for (HloInstruction* operand : inst->operands()) {
3968 if (!check(operand)) {
3969 return false;
3970 }
3971 worklist.push_front(operand);
3972 }
3973 }
3974 return true;
3975 }
3976
3977 // Moves a cluster of memory-reducing nodes (with reduce nodes at the end)
3978 // into the windowed dot-general loop on non-contracting dimensions. Such a
3979 // loop has a dynamic-update-slice at the output. If we move the user nodes
3980 // into the loop and before the dynamic-update-slice, the user nodes can
3981 // operate on smaller shapes, which reduces memory.
3982 //
3983 // small_operands small_operands
3984 // | | => | |
3985 // | | loop { loop { | |
3986 // | | conv | broadcast conv
3987 // | | | | | /
3988 // | | dynamic-update-slice | dynamic-slice /
3989 // | | | | | /
3990 // | | } | | multiply-----
3991 // |broadcast / | /
3992 // | | / reduce
3993 // |multiply-- |
3994 // \ | dynamic-update-slice
3995 // reduce }
3996 //
3997 // Later optimization passes (TpuPadSliceMover) will merge the dynamic slice
3998 // with the input nodes (broadcast).
3999 Status MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions(
4000 HloInstruction* loop, const SpmdPartitionerOptions& options) {
4001 CHECK_EQ(loop->user_count(), 1);
4002 // There should be a single direct user of the while loop, which is the
4003 // gte for element 2, i.e., the dot output.
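// (In the windowed dot-general loop tuple, index 2 holds the accumulated dot
// result; the remaining elements carry the loop operands and induction state.)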
4004 auto* user_gte = loop->users().front();
4005 CHECK_EQ(user_gte->opcode(), HloOpcode::kGetTupleElement);
4006 CHECK_EQ(user_gte->tuple_index(), 2);
4007 auto* computation = loop->parent();
4008
4009 // Find the reduce outputs and the input nodes they depend on, if input
4010 // nodes only have small operands.
4011 absl::flat_hash_set<HloInstruction*> to_move;
4012 std::vector<HloInstruction*> new_operands;
4013 absl::flat_hash_set<const HloInstruction*> new_operands_set;
4014 std::vector<HloInstruction*> reduce_outputs;
4015 std::vector<HloInstruction*> worklist;
4016 Shape padded_shape = user_gte->shape();
4017 Shape unpadded_shape = user_gte->shape();
4018 auto* original_output = user_gte;
4019
4020 if (user_gte->user_count() == 1 &&
4021 user_gte->users().back()->opcode() == HloOpcode::kSlice) {
4022 original_output = user_gte->users().back();
4023 unpadded_shape = original_output->shape();
4024 }
4025 for (auto* u : original_output->users()) {
4026 worklist.push_back(u);
4027 }
4028 to_move.insert(original_output);
4029 while (!worklist.empty()) {
4030 auto* inst = worklist.back();
4031 worklist.pop_back();
4032 if (to_move.count(inst) > 0) {
4033 continue;
4034 }
4035 // We only support reduces with a simple reduction function, since we may
4036 // need to accumulate across iterations manually.
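// A "simple" reducer here is, e.g., a computation of the form
// (p0, p1) -> add(p0, p1) (or min/max/etc.): exactly two parameters and a
// single elementwise root, which is what the checks below require.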
4037 if (inst->opcode() == HloOpcode::kReduce &&
4038 inst->to_apply()->instruction_count() == 3 &&
4039 inst->to_apply()->num_parameters() == 2 &&
4040 inst->to_apply()->root_instruction()->IsElementwise()) {
4041 to_move.insert(inst);
4042 auto* other_operand = inst->mutable_operand(1);
4043 auto res = new_operands_set.emplace(other_operand);
4044 if (res.second) {
4045 new_operands.push_back(other_operand);
4046 }
4047 reduce_outputs.push_back(inst);
4048 } else if (inst != computation->root_instruction() &&
4049 inst->user_count() > 0 && inst->IsElementwise() &&
4050 !inst->HasSideEffectNoRecurse() &&
4051 absl::c_all_of(inst->operands(),
4052 [inst](const HloInstruction* o) {
4053 return ShapeUtil::CompatibleIgnoringElementType(
4054 o->shape(), inst->shape());
4055 })) {
4056 // For an elementwise op, we need to make sure that it depends only on
4057 // nodes already in to_move and on nodes with small operands.
4058 bool can_include = true;
4059 for (auto* operand : inst->operands()) {
4060 if (to_move.count(operand) > 0) {
4061 continue;
4062 }
4063 auto find_result = FindInputNodesIfOnlyDependOnSmallOperands(operand);
4064 if (find_result.first.empty()) {
4065 can_include = false;
4066 break;
4067 }
4068 for (auto* n : find_result.first) {
4069 to_move.insert(n);
4070 }
4071 for (auto* new_operand : find_result.second) {
4072 auto res = new_operands_set.insert(new_operand);
4073 if (res.second) {
4074 new_operands.push_back(new_operand);
4075 }
4076 }
4077 }
4078 if (!can_include) {
4079 to_move.clear();
4080 break;
4081 }
4082 to_move.insert(inst);
4083 for (auto* u : inst->users()) {
4084 worklist.push_back(u);
4085 }
4086 } else {
4087 to_move.clear();
4088 break;
4089 }
4090 }
4091 // If nothing was found, to_move either contains only original_output or was
4092 // cleared by the code above.
4093 if (to_move.size() <= 1) {
4094 return OkStatus();
4095 }
4096
4097 // If one reduce depends on another reduce, we can't do the code motion, as
4098 // it would create a circular dependency.
4099 if (reduce_outputs.size() > 10) {
4100 // When there are many reduces, it might be faster to build a reachability
4101 // map, and then check pair-wise reachability.
4102 auto reachability = HloReachabilityMap::Build(computation);
4103 for (const HloInstruction* reduce : reduce_outputs) {
4104 for (const HloInstruction* other_reduce : reduce_outputs) {
4105 if (reduce != other_reduce &&
4106 reachability->IsReachable(reduce, other_reduce)) {
4107 return OkStatus();
4108 }
4109 }
4110 }
4111 } else if (reduce_outputs.size() > 1) {
4112 // When there are only a few reduces, we can do a traversal to check dependencies.
4113 for (const HloInstruction* reduce : reduce_outputs) {
4114 auto reduce_outputs_do_not_contain = [&](const HloInstruction* inst) {
4115 return !absl::c_linear_search(reduce_outputs, inst);
4116 };
4117 if (!CheckOperandsRecursive(reduce, reduce_outputs_do_not_contain)) {
4118 return OkStatus();
4119 }
4120 }
4121 }
4122
4123 // We will replace the original loop output with reduce-shape outputs.
4124 // Create the initial buffers before the loop.
4125 for (auto* out : reduce_outputs) {
4126 Shape padded_out_shape = out->shape();
4127 int64_t operand_dim = 0;
4128 int64_t output_dim = 0;
4129 while (output_dim < padded_out_shape.rank()) {
4130 if (absl::c_linear_search(out->dimensions(), operand_dim)) {
4131 // Dimension collapsed.
4132 ++operand_dim;
4133 continue;
4134 }
4135 // Kept dimensions have the same size as the padded shape.
4136 padded_out_shape.set_dimensions(output_dim,
4137 padded_shape.dimensions(operand_dim));
4138 ++operand_dim;
4139 ++output_dim;
4140 }
4141 auto* broadcast =
4142 computation->AddInstruction(HloInstruction::CreateBroadcast(
4143 padded_out_shape,
4144 computation->AddInstruction(HloInstruction::CreateConstant(
4145 LiteralUtil::Zero(out->shape().element_type()))),
4146 {}));
4147 new_operands.push_back(broadcast);
4148 }
4149
4150 auto* input_tuple = loop->mutable_operand(0);
4151 // Create the new input subtuple that contains the small operands and the
4152 // reduce-shape result buffers.
4153 auto* new_input_subtuple =
4154 computation->AddInstruction(HloInstruction::CreateTuple(new_operands));
4155 TF_RETURN_IF_ERROR(
4156 input_tuple->ReplaceOperandWithDifferentShape(2, new_input_subtuple));
4157 auto* body = loop->while_body();
4158 auto* body_param = body->parameter_instruction(0);
4159 auto* body_root = body->root_instruction();
4160 CHECK_EQ(body_root->opcode(), HloOpcode::kTuple);
4161 // Update tuple shapes.
4162 for (auto* tuple : std::vector<HloInstruction*>{
4163 input_tuple, loop, loop->while_condition()->parameter_instruction(0),
4164 body_param, body_root}) {
4165 *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(), {2}) =
4166 new_input_subtuple->shape();
4167 }
4168 auto* new_loop_input =
4169 body->AddInstruction(HloInstruction::CreateGetTupleElement(
4170 new_input_subtuple->shape(), body_param, 2));
4171
4172 // Represents a cluster that's associated with a single dynamic-update-slice
4173 // op, which should be moved inside the windowed dot-general loop. There
4174 // might be multiple clusters associated with multiple dynamic-update-slice
4175 // ops, all of which need moving.
4176 struct MotionCluster {
4177 HloInstruction* dus;
4178 absl::flat_hash_map<const HloInstruction*, HloInstruction*>
4179 outside_to_inside;
4180 std::vector<HloInstruction*> slice_offsets;
4181 };
4182
4183 std::vector<MotionCluster> motion_clusters;
4184
4185 // The elementwise nodes will be created with the sliced shape. The original
4186 // loop output corresponds to the dynamic-update-slice's update slice.
4187 {
4188 HloInstruction* dus = body_root->mutable_operand(2);
4189 while (dus->opcode() == HloOpcode::kDynamicUpdateSlice) {
4190 motion_clusters.emplace_back();
4191 motion_clusters.back().dus = dus;
4192 motion_clusters.back().outside_to_inside[original_output] =
4193 dus->mutable_operand(1);
4194 motion_clusters.back().slice_offsets.reserve(padded_shape.rank());
4195 for (int64_t i = 0; i < padded_shape.rank(); ++i) {
4196 motion_clusters.back().slice_offsets.push_back(
4197 motion_clusters.back().dus->mutable_operand(i + 2));
4198 }
4199 dus = dus->mutable_operand(0);
4200 }
4201 }
4202 // There is at least one cluster that needs moving.
4203 CHECK_GE(motion_clusters.size(), 1);
4204 MotionCluster& base_motion_cluster = motion_clusters[0];
4205
4206 worklist.clear();
4207 auto add_users_if_available = [&](HloInstruction* inst) {
4208 for (auto* u : inst->users()) {
4209 if (base_motion_cluster.outside_to_inside.count(u) == 0 &&
4210 to_move.count(u) > 0 &&
4211 absl::c_all_of(u->operands(), [&](const HloInstruction* o) {
4212 return base_motion_cluster.outside_to_inside.count(o) > 0;
4213 })) {
4214 worklist.push_back(u);
4215 }
4216 }
4217 };
4218
4219 for (int64_t i = 0; i < new_operands.size(); ++i) {
4220 auto* operand_gte =
4221 body->AddInstruction(HloInstruction::CreateGetTupleElement(
4222 new_operands[i]->shape(), new_loop_input, i));
4223 for (MotionCluster& motion_cluster : motion_clusters) {
4224 motion_cluster.outside_to_inside[new_operands[i]] = operand_gte;
4225 }
4226 add_users_if_available(new_operands[i]);
4227 }
4228 add_users_if_available(original_output);
4229
4230 // Now create the moved nodes inside the loop body.
4231 auto get_slice = [&](HloInstruction* padded,
4232 absl::Span<HloInstruction* const> slice_offsets,
4233 HloInstruction* dus) {
4234 return body->AddInstruction(HloInstruction::CreateDynamicSlice(
4235 ShapeUtil::ChangeElementType(dus->operand(1)->shape(),
4236 padded->shape().element_type()),
4237 padded, slice_offsets, dus->operand(1)->shape().dimensions()));
4238 };
4239 // Helper functions to create nodes with small operands.
4240 auto add_broadcast = [&](const HloInstruction* broadcast) {
4241 Shape padded_operand_shape = broadcast->operand(0)->shape();
4242 for (int64_t i = 0; i < broadcast->dimensions().size(); ++i) {
4243 padded_operand_shape.set_dimensions(
4244 i, padded_shape.dimensions(broadcast->dimensions(i)));
4245 }
4246 auto* padded_operand =
4247 PadToShape(base_motion_cluster.outside_to_inside[broadcast->operand(0)],
4248 padded_operand_shape, body);
4249 auto* inside_broadcast =
4250 body->AddInstruction(broadcast->CloneWithNewOperands(
4251 ShapeUtil::ChangeElementType(padded_shape,
4252 padded_operand_shape.element_type()),
4253 {padded_operand}));
4254 for (MotionCluster& motion_cluster : motion_clusters) {
4255 motion_cluster.outside_to_inside[broadcast] = get_slice(
4256 inside_broadcast, motion_cluster.slice_offsets, motion_cluster.dus);
4257 }
4258 };
4259 auto add_iota = [&](const HloInstruction* iota) {
4260 auto* inside_iota = body->AddInstruction(iota->CloneWithNewOperands(
4261 ShapeUtil::ChangeElementType(padded_shape,
4262 iota->shape().element_type()),
4263 {}));
4264 for (MotionCluster& motion_cluster : motion_clusters) {
4265 motion_cluster.outside_to_inside[iota] = get_slice(
4266 inside_iota, motion_cluster.slice_offsets, motion_cluster.dus);
4267 }
4268 };
4269 auto add_constant = [&](const HloInstruction* constant) {
4270 auto* constant_clone = body->AddInstruction(constant->Clone());
4271 auto* inside_constant =
4272 PadToShape(constant_clone,
4273 ShapeUtil::ChangeElementType(
4274 padded_shape, constant->shape().element_type()),
4275 body);
4276 for (MotionCluster& motion_cluster : motion_clusters) {
4277 motion_cluster.outside_to_inside[constant] = get_slice(
4278 inside_constant, motion_cluster.slice_offsets, motion_cluster.dus);
4279 }
4280 };
4281 auto add_other_inst = [&](const HloInstruction* inst) {
4282 std::vector<HloInstruction*> operands_inside(inst->operand_count());
4283 for (MotionCluster& motion_cluster : motion_clusters) {
4284 for (int64_t i = 0; i < operands_inside.size(); ++i) {
4285 operands_inside[i] = motion_cluster.outside_to_inside[inst->operand(i)];
4286 }
4287 motion_cluster.outside_to_inside[inst] =
4288 body->AddInstruction(inst->CloneWithNewOperands(
4289 ShapeUtil::ChangeElementType(
4290 motion_cluster.dus->operand(1)->shape(),
4291 inst->shape().element_type()),
4292 operands_inside));
4293 }
4294 };
4295
4296 while (!worklist.empty()) {
4297 auto* inst = worklist.back();
4298 worklist.pop_back();
4299 if (absl::c_all_of(
4300 motion_clusters, [inst](const MotionCluster& motion_cluster) {
4301 return motion_cluster.outside_to_inside.count(inst) > 0;
4302 })) {
4303 continue;
4304 }
4305 if (inst->opcode() == HloOpcode::kBroadcast) {
4306 add_broadcast(inst);
4307 } else if (inst->opcode() == HloOpcode::kIota) {
4308 add_iota(inst);
4309 } else if (inst->opcode() == HloOpcode::kConstant) {
4310 add_constant(inst);
4311 } else if (inst->opcode() == HloOpcode::kReduce) {
4312 // This is an output, for which we have special handling later.
4313 } else if (inst->IsElementwise()) {
4314 add_other_inst(inst);
4315 } else {
4316 // Skip cloning other non-elementwise ops.
4317 }
4318 add_users_if_available(inst);
4319 }
4320 std::vector<HloInstruction*> new_outputs_inside(new_operands.size());
4321 for (int64_t i = 0; i < new_outputs_inside.size(); ++i) {
4322 new_outputs_inside[i] =
4323 base_motion_cluster.outside_to_inside[new_operands[i]];
4324 }
4325
4326 // Now create the reduce outputs inside of the loop.
4327 for (int64_t i = 0; i < reduce_outputs.size(); ++i) {
4328 auto* reduce_outside = reduce_outputs[i];
4329 CHECK_EQ(reduce_outside->opcode(), HloOpcode::kReduce);
4330 int64_t index_in_operand = new_operands.size() - reduce_outputs.size() + i;
4331 auto* last_iter_result =
4332 base_motion_cluster.outside_to_inside[new_operands[index_in_operand]];
4333
4334 auto create_inside_reduce =
4335 [&](absl::flat_hash_map<const HloInstruction*, HloInstruction*>&
4336 outside_to_inside,
4337 absl::Span<HloInstruction* const> slice_offsets,
4338 HloInstruction* last_iter_result) -> StatusOr<HloInstruction*> {
4339 HloInstruction* operand0 = outside_to_inside[reduce_outside->operand(0)];
4340 HloInstruction* operand1 = outside_to_inside[reduce_outside->operand(1)];
4341 TF_ASSIGN_OR_RETURN(
4342 Shape reduce_shape,
4343 ShapeInference::InferReduceShape(
4344 {&operand0->shape(), &operand1->shape()},
4345 reduce_outside->dimensions(),
4346 reduce_outside->to_apply()->ComputeProgramShape()));
4347 *reduce_shape.mutable_layout() = reduce_outside->shape().layout();
4348 std::vector<HloInstruction*> reduce_dus_offsets;
4349 // If any collapsed dimension is windowed, we need to accumulate with the
4350 // last iteration's result. If such a dimension has padding, we also need
4351 // to mask off invalid data.
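// E.g., if a reduced dimension is the windowed one, each iteration only sees
// a slice of it, so its partial reduction must be added into the result
// carried over from the previous iteration, with padded elements first
// replaced by the reduce's init value.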
4352 bool needs_accumulate = false;
4353 std::vector<int64_t> dims_to_mask;
4354 for (int64_t i = 0; i < slice_offsets.size(); ++i) {
4355 if (absl::c_linear_search(reduce_outside->dimensions(), i)) {
4356 if (reduce_outside->operand(0)->shape().dimensions(i) !=
4357 operand0->shape().dimensions(i)) {
4358 needs_accumulate = true;
4359 if (unpadded_shape.dimensions(i) != padded_shape.dimensions(i)) {
4360 dims_to_mask.push_back(i);
4361 }
4362 }
4363 continue;
4364 }
4365 reduce_dus_offsets.push_back(slice_offsets[i]);
4366 }
4367 // Mask off invalid data in collapsed dimensions.
4368 for (int64_t dim : dims_to_mask) {
4369 auto* iota = body->AddInstruction(HloInstruction::CreateIota(
4370 ShapeUtil::ChangeElementType(operand0->shape(), S32), dim));
4371 auto* add = body->AddInstruction(HloInstruction::CreateBinary(
4372 iota->shape(), HloOpcode::kAdd, iota,
4373 body->AddInstruction(HloInstruction::CreateBroadcast(
4374 iota->shape(), slice_offsets[dim], {}))));
4375 auto* limit = body->AddInstruction(HloInstruction::CreateBroadcast(
4376 iota->shape(),
4377 body->AddInstruction(
4378 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(
4379 reduce_outside->operand(0)->shape().dimensions(dim)))),
4380 {}));
4381 auto* compare = body->AddInstruction(HloInstruction::CreateCompare(
4382 ShapeUtil::ChangeElementType(iota->shape(), PRED), add, limit,
4383 ComparisonDirection::kLt));
4384 operand0 = body->AddInstruction(HloInstruction::CreateTernary(
4385 operand0->shape(), HloOpcode::kSelect, compare, operand0,
4386 body->AddInstruction(HloInstruction::CreateBroadcast(
4387 operand0->shape(), operand1, {}))));
4388 }
4389 auto* output_inside =
4390 body->AddInstruction(reduce_outside->CloneWithNewOperands(
4391 reduce_shape, {operand0, operand1}));
4392 // Accumulate with previous results if needed.
4393 if (needs_accumulate) {
4394 auto* input_slice =
4395 body->AddInstruction(HloInstruction::CreateDynamicSlice(
4396 output_inside->shape(), last_iter_result, reduce_dus_offsets,
4397 output_inside->shape().dimensions()));
4398 output_inside = body->AddInstruction(HloInstruction::CreateBinary(
4399 output_inside->shape(),
4400 reduce_outside->to_apply()->root_instruction()->opcode(),
4401 output_inside, input_slice));
4402 }
4403 // Dynamic-update-slice if needed.
4404 if (!ShapeUtil::Compatible(output_inside->shape(),
4405 last_iter_result->shape())) {
4406 output_inside =
4407 body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
4408 last_iter_result->shape(), last_iter_result, output_inside,
4409 reduce_dus_offsets));
4410 }
4411 return output_inside;
4412 };
4413 for (MotionCluster& motion_cluster : motion_clusters) {
4414 TF_ASSIGN_OR_RETURN(
4415 last_iter_result,
4416 create_inside_reduce(motion_cluster.outside_to_inside,
4417 motion_cluster.slice_offsets, last_iter_result));
4418 }
4419 new_outputs_inside[index_in_operand] = last_iter_result;
4420 }
4421
4422 // Body output.
4423 auto* new_output_inside =
4424 body->AddInstruction(HloInstruction::CreateTuple(new_outputs_inside));
4425 TF_RETURN_IF_ERROR(
4426 body_root->ReplaceOperandWithDifferentShape(2, new_output_inside));
4427 TF_RETURN_IF_ERROR(
4428 body->RemoveInstructionAndUnusedOperands(base_motion_cluster.dus));
4429 // Replace uses of the reduces outside the loop.
4430 auto* new_output_gte =
4431 computation->AddInstruction(HloInstruction::CreateGetTupleElement(
4432 new_output_inside->shape(), loop, 2));
4433 for (int64_t i = 0; i < reduce_outputs.size(); ++i) {
4434 int64_t index_in_operand = new_operands.size() - reduce_outputs.size() + i;
4435 auto* new_output =
4436 computation->AddInstruction(HloInstruction::CreateGetTupleElement(
4437 new_outputs_inside[index_in_operand]->shape(), new_output_gte,
4438 index_in_operand));
4439 if (!ShapeUtil::Compatible(new_output->shape(),
4440 reduce_outputs[i]->shape())) {
4441 new_output = computation->AddInstruction(HloInstruction::CreateSlice(
4442 reduce_outputs[i]->shape(), new_output,
4443 std::vector<int64_t>(new_output->shape().rank(), 0),
4444 reduce_outputs[i]->shape().dimensions(),
4445 std::vector<int64_t>(new_output->shape().rank(), 1)));
4446 }
4447 TF_RETURN_IF_ERROR(reduce_outputs[i]->ReplaceAllUsesWith(new_output));
4448 TF_RETURN_IF_ERROR(
4449 computation->RemoveInstructionAndUnusedOperands(reduce_outputs[i]));
4450 }
4451 return OkStatus();
4452 }
4453
4454 } // namespace
4455
4456 Status SpmdPartitioningVisitor::DoCodeMotionForWindowedDotGeneralLoops(
4457 HloComputation* computation, const SpmdPartitionerOptions& options) {
4458 for (auto& loop : windowed_dot_general_loops_) {
4459 if (loop.windowed_in_contracting_dims || loop.windowed_in_batch_dims ||
4460 loop.operands_sharded_at_contracting_dims) {
4461 // We have a dynamic-slice for the non-windowed operand in
4462 // batch/contracting-dim/noncontracting-dim windowed dot-general. So
4463 // moving the broadcast/iota/elementwise ops into the loop could help
4464 // reduce memory via fusion.
4465 TF_RETURN_IF_ERROR(
4466 SinkInputNodesIntoWindowedDotGeneralLoopOnContractingDimensions(
4467 loop.while_loop, 1 - loop.windowed_operand));
4468 }
4469 // Currently, the unrolled loop does not support this optimization.
4470 if (!loop.windowed_in_contracting_dims &&
4471 !loop.operands_sharded_at_contracting_dims) {
4472 // We have a dynamic-update-slice for the output in
4473 // batch/non-contracting-dim windowed dot-general. So moving reduce ops
4474 // into the loop could help reduce memory.
4475 TF_RETURN_IF_ERROR(
4476 MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions(
4477 loop.while_loop, options));
4478 }
4479 }
4480 return OkStatus();
4481 }
4482
4483 } // namespace spmd
4484 } // namespace xla
4485