/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <float.h>

#include <cmath>
#include <functional>
#include <memory>
#include <numeric>
#include <optional>
#include <vector>

#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace spmd {

namespace {

// Pad each partition to have a size that is a multiple of num_partitions.
// For example, if the input is {0, 1, 2, 3, 4, 5} and num_partitions = 2,
// after padding, partition 0 has {0, 1, 2, 3} and partition 1 has
// {4, 5, 0, 0}.
std::optional<HloInstruction*> PadEachPartitionWithHaloExchange(
    HloInstruction* hlo, int64_t num_partitions, const HloSharding& sharding,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64_t* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) {
  int64_t size_per_partition = hlo->shape().dimensions().back();
  int64_t size_padded_per_partition =
      CeilOfRatio(size_per_partition, num_partitions) * num_partitions;
  if (size_per_partition == size_padded_per_partition) {
    return hlo;
  }
  // 1. Calculate left_halo size.
  // left-halo size is 0
  OffsetCalculation left_halo_size_function =
      OffsetCalculation(MultiplyAddDivideOffsetCalculation(0, 0, 1));

  // 2. Calculate right_halo size.
  // D = size_padded_per_partition
  // S = size_per_partition
  // i = shard_ordinal
  // right-halo size is D * (i + 2) - S * (i + 2) = (D - S) * i + 2 * (D - S)
  OffsetCalculation right_halo_size_function =
      OffsetCalculation(MultiplyAddDivideOffsetCalculation(
          size_padded_per_partition - size_per_partition,
          2 * (size_padded_per_partition - size_per_partition), 1));
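  // A worked instance of this arithmetic for the {0, 1, 2, 3, 4, 5} example
  // above (a sketch; actual halo contents depend on ExchangeHalo's padding):
  // S = 3 and D = 4, so shard 0 receives a right halo of
  // (D - S) * (0 + 2) = 2 elements ({3, 4}) and the dynamic-slice below at
  // offset (D - S) * i = 0 keeps {0, 1, 2, 3}; shard 1's slice at offset 1
  // keeps {4, 5, 0, 0}, the trailing elements being padding.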

  auto concat = hlo;
  // 3. Halo exchange.
  auto halo_exchange_result =
      ExchangeHalo(hlo, left_halo_size_function, right_halo_size_function,
                   hlo->shape().rank() - 1, sharding, collective_ops_creator,
                   next_channel_id, b);

  if (halo_exchange_result.has_value()) {
    concat = halo_exchange_result.value();
  } else {
    return std::nullopt;
  }

  // 4. Slice the valid result.
  // Slice offset is (D - S) * i
  OffsetCalculation start_offset_on_padded_concat_calculation =
      OffsetCalculation(MultiplyAddDivideOffsetCalculation(
          size_padded_per_partition - size_per_partition, 0, 1));
  auto slice_shape = concat->shape();
  slice_shape.set_dimensions(concat->shape().rank() - 1,
                             size_padded_per_partition);
  auto zero_s32 =
      b->AddInstruction(HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
  std::vector<HloInstruction*> slice_offsets(concat->shape().rank(), zero_s32);
  auto partition_ordinals =
      MakeTiledPartitionOrdinals(sharding, partition_id, b);
  slice_offsets[concat->shape().rank() - 1] =
      start_offset_on_padded_concat_calculation.Calculate(
          partition_ordinals[concat->shape().rank() - 1], b);
  return b->AddInstruction(HloInstruction::CreateDynamicSlice(
      slice_shape, concat, slice_offsets, slice_shape.dimensions()));
}

// If partition 0 has {0, 1, 2, 3} and num_partitions is 2, after shuffling,
// the data becomes {0, 2, 1, 3}.
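// The shuffle is expressed as a matmul against a one-hot permutation matrix:
// an iota over the last dimension is reshaped to
// [size_per_partition / num_partitions, num_partitions], transposed, and
// flattened to produce the target element order, then compared against a
// row-index iota to build the 0/1 matrix that permutes the last dimension.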
HloInstruction* ShuffleWithinEachPartitionUsingOneHot(HloInstruction* hlo,
                                                      int64_t num_partitions,
                                                      SpmdBuilder* b) {
  int64_t size_per_partition = hlo->shape().dimensions().back();
  CHECK_EQ(size_per_partition % num_partitions, 0);
  auto indices_iota = b->AddInstruction(HloInstruction::CreateIota(
      ShapeUtil::MakeShape(S32, {size_per_partition}), 0));
  auto reshape_indices_iota = b->AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(
          S32, {size_per_partition / num_partitions, num_partitions}),
      indices_iota));
  auto transpose_indices_iota =
      b->AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::MakeShape(
              S32, {num_partitions, size_per_partition / num_partitions}),
          reshape_indices_iota, {1, 0}));
  auto one_hot_indices = b->AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(S32, {size_per_partition, size_per_partition}),
      b->AddInstruction(HloInstruction::CreateReshape(
          ShapeUtil::MakeShape(S32, {size_per_partition}),
          transpose_indices_iota)),
      /*broadcast_dimensions=*/{1}));

  auto partition_indices = b->AddInstruction(HloInstruction::CreateIota(
      ShapeUtil::MakeShape(S32, {size_per_partition, size_per_partition}), 0));

  auto shuffle_one_hot = b->AddInstruction(HloInstruction::CreateConvert(
      ShapeUtil::ChangeElementType(partition_indices->shape(),
                                   hlo->shape().element_type()),
      b->AddInstruction(HloInstruction::CreateCompare(
          ShapeUtil::ChangeElementType(partition_indices->shape(), PRED),
          one_hot_indices, partition_indices, ComparisonDirection::kEq))));

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(hlo->shape().rank() - 1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  PrecisionConfig precision_config;
  precision_config.mutable_operand_precision()->Resize(
      2, PrecisionConfig::DEFAULT);
  HloInstruction* dot = b->AddInstruction(HloInstruction::CreateDot(
      hlo->shape(), hlo, shuffle_one_hot, dot_dnums, precision_config));
  return dot;
}

// If partition 0 has {0, 2, 1, 3}, partition 1 has {4, 0, 5, 0} and
// num_partitions is 2, after all-to-all, partition 0 will have {0, 2, 4, 0}
// and partition 1 will have {1, 3, 5, 0}.
HloInstruction* ShuffleDataWithAllToAll(
    HloInstruction* hlo, int64_t num_partitions,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64_t* next_channel_id, SpmdBuilder* b) {
  std::vector<std::vector<int64_t>> groups(1);
  std::vector<int64_t> partition_subgroups(num_partitions);
  std::iota(partition_subgroups.begin(), partition_subgroups.end(), 0);
  groups[0] = partition_subgroups;
  auto all_to_all = collective_ops_creator.create_cross_partition_all_to_all(
      b, {hlo}, groups, (*next_channel_id)++, hlo->shape().rank() - 1);
  return all_to_all;
}

HloInstruction* GetCorrectionFactor(HloInstruction* hlo, int64_t num_partitions,
                                    HloInstruction* partition_id,
                                    SpmdBuilder* b) {
  /* n = size_per_replica
     m = num_partitions
     factor = tf.exp(-2.0j * np.pi * tf.cast(position_index, tf.complex64) *
                     tf.cast(tf.range(n), dtype=tf.complex64) /
                     (n * m))
  */
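  // This factor is effectively the Cooley-Tukey twiddle factor
  // W_N^(i * k) = exp(-2 * pi * j * i * k / N) with N = n * m, where i is the
  // partition id and k the element position within the partition; it stitches
  // the local size-n FFTs into the global size-N transform.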
  auto add_hlo = [&](std::unique_ptr<HloInstruction> to_add) {
    return b->AddInstruction(std::move(to_add));
  };
  int64_t per_replica_size = hlo->shape().dimensions().back();
  auto constant_factor =
      add_hlo(HloInstruction::CreateConstant(LiteralUtil::CreateR0(
          complex64(0, -2.0 * M_PI / (num_partitions * per_replica_size)))));
  constant_factor = add_hlo(HloInstruction::CreateBroadcast(
      hlo->shape(), constant_factor, /*broadcast_dimensions=*/{}));
  auto converted_partition_id = add_hlo(HloInstruction::CreateConvert(
      ShapeUtil::ChangeElementType(partition_id->shape(),
                                   hlo->shape().element_type()),
      partition_id));
  // TODO(wangtao): multiply before broadcast.
  auto broadcast_partition_id = add_hlo(HloInstruction::CreateBroadcast(
      hlo->shape(), converted_partition_id, /*broadcast_dimensions=*/{}));
  auto exp_operand = add_hlo(
      HloInstruction::CreateBinary(hlo->shape(), HloOpcode::kMultiply,
                                   constant_factor, broadcast_partition_id));
  auto iota = add_hlo(
      HloInstruction::CreateIota(hlo->shape(), hlo->shape().rank() - 1));
  exp_operand = add_hlo(HloInstruction::CreateBinary(
      hlo->shape(), HloOpcode::kMultiply, exp_operand, iota));
  return add_hlo(
      HloInstruction::CreateUnary(hlo->shape(), HloOpcode::kExp, exp_operand));
}

// Pseudocode for the while loop:
// def body(dest_transform, dest_core_position, source_transform,
//          source_core_position, i):
//   factor = tf.exp(-2.0j * np.pi *
//                   tf.cast(dest_core_position, tf.complex64) *
//                   tf.cast(source_core_position, tf.complex64) /
//                   num_partitions)
//   dest_transform += factor * source_transform
//   source_core_position = tf.raw_ops.CollectivePermute(
//       input=source_core_position,
//       source_target_pairs=source_target_pairs,
//       name='source_core_position_permute')
//   source_transform = tf.raw_ops.CollectivePermute(
//       input=source_transform,
//       source_target_pairs=source_target_pairs,
//       name='source_transform_permute')
//   i += 1
//   return (dest_transform, dest_core_position, source_transform,
//           source_core_position, i)
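// In other words, after num_partitions iterations partition p has accumulated
// dest[p] = sum_q exp(-2 * pi * j * p * q / num_partitions) * source[q], a
// size-num_partitions DFT across partitions, with the source data rotated one
// partition per iteration by collective permute.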
HloInstruction* GetFinalFftUsingCollectivePermute(
    HloInstruction* hlo, const HloSharding& sharding,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64_t num_partitions, HloInstruction* partition_id,
    int64_t* next_channel_id, HloModule* module, SpmdBuilder* b) {
  auto iteration = b->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32_t>(0)));
  auto converted_partition_id = b->AddInstruction(HloInstruction::CreateConvert(
      ShapeUtil::ChangeElementType(partition_id->shape(),
                                   hlo->shape().element_type()),
      partition_id));
  // Build while loop body.
  SpmdBuilder body_b("fft_collective_permute_body", hlo);
  auto param = body_b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0,
      ShapeUtil::MakeTupleShape(
          {hlo->shape(), hlo->shape(), converted_partition_id->shape(),
           converted_partition_id->shape(), iteration->shape()}),
      "param"));
  auto dest_transform = body_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(hlo->shape(), param, 0));
  auto source_transform = body_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(hlo->shape(), param, 1));
  auto dest_partition_id =
      body_b.AddInstruction(HloInstruction::CreateGetTupleElement(
          converted_partition_id->shape(), param, 2));
  auto source_partition_id =
      body_b.AddInstruction(HloInstruction::CreateGetTupleElement(
          converted_partition_id->shape(), param, 3));
  auto i = body_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(iteration->shape(), param, 4));
  /*
    factor = tf.exp(-2.0j * np.pi *
                    tf.cast(dest_partition_id, tf.complex64) *
                    tf.cast(source_partition_id, tf.complex64) /
                    num_partitions)
    dest_transform += factor * source_transform
  */
  auto constant_factor = body_b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR0(complex64(0, -2.0 * M_PI / num_partitions))));

  constant_factor = body_b.AddInstruction(HloInstruction::CreateBinary(
      constant_factor->shape(), HloOpcode::kMultiply, constant_factor,
      dest_partition_id));
  constant_factor = body_b.AddInstruction(HloInstruction::CreateBinary(
      constant_factor->shape(), HloOpcode::kMultiply, constant_factor,
      source_partition_id));
  auto phase_factor = body_b.AddInstruction(HloInstruction::CreateUnary(
      constant_factor->shape(), HloOpcode::kExp, constant_factor));
  phase_factor = body_b.AddInstruction(
      HloInstruction::CreateBroadcast(hlo->shape(), phase_factor, {}));
  auto phase_adjust_source_transform =
      body_b.AddInstruction(HloInstruction::CreateBinary(
          hlo->shape(), HloOpcode::kMultiply, phase_factor, source_transform));
  dest_transform = body_b.AddInstruction(HloInstruction::CreateBinary(
      hlo->shape(), HloOpcode::kAdd, phase_adjust_source_transform,
      dest_transform));
  // Collective permute for source partition_id and source_transform.
  std::vector<std::pair<int64_t, int64_t>> src_dst_pairs;
  sharding.tile_assignment().Each(
      [&](absl::Span<const int64_t> indices, int64_t src_device) {
        std::vector<int64_t> target_indices(indices.begin(), indices.end());
        target_indices.back() = (indices.back() + 1) % num_partitions;
        int64_t dst_device = sharding.tile_assignment()(target_indices);
        src_dst_pairs.emplace_back(src_device, dst_device);
      });
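  // Each partition sends to its neighbor at (index + 1) % num_partitions
  // along the partitioned dimension, so after k iterations partition p holds
  // the data that started on partition (p - k) mod num_partitions.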

  source_partition_id =
      collective_ops_creator.create_cross_partition_collective_permute(
          &body_b, source_partition_id, src_dst_pairs, (*next_channel_id)++);

  source_transform =
      collective_ops_creator.create_cross_partition_collective_permute(
          &body_b, source_transform, src_dst_pairs, (*next_channel_id)++);

  // ++i
  i = body_b.AddInstruction(HloInstruction::CreateBinary(
      i->shape(), HloOpcode::kAdd, i,
      body_b.AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32_t>(1)))));
  body_b.AddInstruction(
      HloInstruction::CreateTuple({dest_transform, source_transform,
                                   dest_partition_id, source_partition_id, i}));

  // Build while loop conditions.
  auto zero = CreateZero(hlo->shape(), b);
  SpmdBuilder cond_b("fft_collective_permute_condition", hlo);
  auto cond_param = cond_b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0,
      ShapeUtil::MakeTupleShape(
          {hlo->shape(), hlo->shape(), converted_partition_id->shape(),
           converted_partition_id->shape(), iteration->shape()}),
      "param"));
  auto cond_i = cond_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(iteration->shape(), cond_param, 4));
  cond_b.AddInstruction(HloInstruction::CreateCompare(
      ShapeUtil::MakeShape(PRED, {}), cond_i,
      cond_b.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR0<uint32_t>(num_partitions))),
      ComparisonDirection::kLt));

  // Build while loop.
  auto while_loop = b->AddInstruction(HloInstruction::CreateWhile(
      cond_param->shape(), module->AddEmbeddedComputation(cond_b.Build()),
      module->AddEmbeddedComputation(body_b.Build()),
      b->AddInstruction(
          HloInstruction::CreateTuple({zero, hlo, converted_partition_id,
                                       converted_partition_id, iteration}))));

  return b->AddInstruction(
      HloInstruction::CreateGetTupleElement(hlo->shape(), while_loop, 0));
}

// Slice valid data in each partition.
HloInstruction* SliceValidData(HloInstruction* hlo, const Shape& target_shape,
                               SpmdBuilder* b) {
  std::vector<int64_t> start_indices(target_shape.rank(), 0);
  std::vector<int64_t> strides(target_shape.rank(), 1);
  return b->AddInstruction(HloInstruction::CreateSlice(
      target_shape, hlo, start_indices, target_shape.dimensions(), strides));
}

}  // namespace

// Distributed FFT using the algorithm described in go/tpu-spmd-fft.
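// In outline, for global fft_length N partitioned over m = num_partitions
// devices with n = N / m elements each (a summary inferred from the step
// comments below): (1) redistribute the data so partition p holds elements
// p, p + m, p + 2m, ...; (2) run a local length-n FFT and apply per-element
// twiddle factors; (3) combine the partial transforms with a size-m DFT
// across partitions via collective permute.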
Status SpmdPartitioningVisitor::HandleFft(HloInstruction* hlo) {
  if (hlo->operand(0)->shape().rank() < 3 || hlo->fft_type() != FftType::FFT) {
    return DefaultAction(hlo);
  }

  // Only support the case where input_length equals fft_length.
  int64_t input_length = hlo->operand(0)->shape().dimensions().back();
  int64_t fft_length = hlo->fft_length().back();
  if (input_length != fft_length || input_length % num_partitions_ != 0) {
    return DefaultAction(hlo);
  }

  // Support partition at the last dimension only.
  if (!hlo->has_sharding() ||
      hlo->sharding().tile_assignment().dimensions().back() !=
          num_partitions_) {
    return DefaultAction(hlo);
  }

  auto partitioned_input =
      GetPartitionedHlo(hlo->operand(0))
          .PadWithValue(CreateR0WithType(hlo->shape().element_type(), 0, &b_));

  // 1.a. Use right halo exchange to shuffle the data first and slice out the
  // valid data. Data shuffling ensures an in-order transform, i.e. the
  // sequence of data is the same before and after the transform. The shuffle
  // requires that the size of data per partition be divisible by the number
  // of partitions. For example, if the input is {0, 1, 2, 3, 4, 5} and
  // num_partitions is 2, after the halo exchange partition 0 has {0, 1, 2, 3}
  // and partition 1 has {4, 5, 0, 0}, where the 0s in partition 1 are padding
  // appended to the end of the full data.
  auto result = partitioned_input.hlo();
  auto padded_hlo = PadEachPartitionWithHaloExchange(
      partitioned_input.hlo(), num_partitions_, hlo->sharding(),
      partitioned_input.state().collective_ops_creator,
      partitioned_input.state().next_channel_id,
      partitioned_input.state().partition_id, partitioned_input.state().b);

  if (padded_hlo.has_value()) {
    result = padded_hlo.value();
  }

  // 1.b. Shuffle data within each partition using one-hot and matmul.
  // If partition 0 has {0, 1, 2, 3} and num_partitions is 2, after shuffling,
  // the data becomes {0, 2, 1, 3}.
  result = ShuffleWithinEachPartitionUsingOneHot(result, num_partitions_,
                                                 partitioned_input.state().b);
  // 1.c. All-to-all.
  // If partition 0 has {0, 2, 1, 3}, partition 1 has {4, 0, 5, 0} and
  // num_partitions is 2, after all-to-all, partition 0 will have {0, 2, 4, 0}
  // and partition 1 will have {1, 3, 5, 0}.
  result = ShuffleDataWithAllToAll(
      result, num_partitions_, partitioned_input.state().collective_ops_creator,
      partitioned_input.state().next_channel_id, partitioned_input.state().b);
  // 1.d. Slice valid data in each partition.
  result = SliceValidData(result, partitioned_input.hlo()->shape(), &b_);

  // 2. Do local FFT transform.
  auto partitioned_fft_length = hlo->fft_length();
  partitioned_fft_length.back() /= num_partitions_;
  result = b_.AddInstruction(HloInstruction::CreateFft(
      result->shape(), result, hlo->fft_type(), partitioned_fft_length));

  // Multiply by the correction factor for local phase adjustment.
  auto correction_factor = GetCorrectionFactor(
      result, num_partitions_, partitioned_input.state().partition_id,
      partitioned_input.state().b);
  result = b_.AddInstruction(HloInstruction::CreateBinary(
      result->shape(), HloOpcode::kMultiply, result, correction_factor));

  // 3. Second-phase FFT with collective permute. fft_length = num_partitions.
  result = GetFinalFftUsingCollectivePermute(
      result, hlo->sharding(), partitioned_input.state().collective_ops_creator,
      num_partitions_, partitioned_input.state().partition_id,
      partitioned_input.state().next_channel_id, module_,
      partitioned_input.state().b);

  result->set_sharding(hlo->sharding());
  auto partitioned_fft =
      PartitionedHlo(result, hlo->shape(), partitioned_input.state());
  SetPartitionedHlo(hlo, partitioned_fft);
  return OkStatus();
}

}  // namespace spmd
}  // namespace xla