xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/spmd/fft_handler.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <float.h>

#include <cmath>
#include <functional>
#include <memory>
#include <optional>
#include <vector>

#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace spmd {

namespace {

// Pad each partition so that its size is a multiple of num_partitions.
// For example, if the input is {0, 1, 2, 3, 4, 5} and num_partitions = 2,
// after padding, partition 0 holds {0, 1, 2, 3} and partition 1 holds
// {4, 5, 0, 0}.
std::optional<HloInstruction*> PadEachPartitionWithHaloExchange(
    HloInstruction* hlo, int64_t num_partitions, const HloSharding& sharding,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64_t* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) {
  int64_t size_per_partition = hlo->shape().dimensions().back();
  int64_t size_padded_per_partition =
      CeilOfRatio(size_per_partition, num_partitions) * num_partitions;
  if (size_per_partition == size_padded_per_partition) {
    return hlo;
  }
  // 1. Calculate left_halo size.
  // left-halo size is 0
  OffsetCalculation left_halo_size_function =
      OffsetCalculation(MultiplyAddDivideOffsetCalculation(0, 0, 1));

  // 2. Calculate right_halo size.
  // D = size_padded_per_partition
  // S = size_per_partition
  // i = shard_ordinal
  // right-halo size is D * (i + 2) - S * (i + 2) = (D - S) * i + 2 * (D - S)
  OffsetCalculation right_halo_size_function =
      OffsetCalculation(MultiplyAddDivideOffsetCalculation(
          size_padded_per_partition - size_per_partition,
          2 * (size_padded_per_partition - size_per_partition), 1));
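  // Worked example for the formulas above: with the {0, 1, 2, 3, 4, 5} input
  // and num_partitions = 2, S = 3 and D = 4, so the right-halo size for
  // shard i is (D - S) * (i + 2) = i + 2. Shard 0 pulls {3, 4} from shard 1;
  // shard 1 has no right neighbor, so its halo comes back as the zero
  // padding shown in the example above.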

  auto concat = hlo;
  // 3. Halo exchange.
  auto halo_exchange_result =
      ExchangeHalo(hlo, left_halo_size_function, right_halo_size_function,
                   hlo->shape().rank() - 1, sharding, collective_ops_creator,
                   next_channel_id, b);

  if (halo_exchange_result.has_value()) {
    concat = halo_exchange_result.value();
  } else {
    return std::nullopt;
  }

  // 4. Slice the valid result.
  // Slice offset is (D - S) * i
  OffsetCalculation start_offset_on_padded_concat_calculation =
      OffsetCalculation(MultiplyAddDivideOffsetCalculation(
          size_padded_per_partition - size_per_partition, 0, 1));
  auto slice_shape = concat->shape();
  slice_shape.set_dimensions(concat->shape().rank() - 1,
                             size_padded_per_partition);
  auto zero_s32 =
      b->AddInstruction(HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
  std::vector<HloInstruction*> slice_offsets(concat->shape().rank(), zero_s32);
  auto partition_ordinals =
      MakeTiledPartitionOrdinals(sharding, partition_id, b);
  slice_offsets[concat->shape().rank() - 1] =
      start_offset_on_padded_concat_calculation.Calculate(
          partition_ordinals[concat->shape().rank() - 1], b);
  return b->AddInstruction(HloInstruction::CreateDynamicSlice(
      slice_shape, concat, slice_offsets, slice_shape.dimensions()));
}

// If partition 0 has {0, 1, 2, 3} and num partitions is 2, after shuffling,
// the data becomes {0, 2, 1, 3}.
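// The shuffle is a gather expressed as a dot with a one-hot permutation
// matrix. An iota over the last dimension is reshaped to
// [size_per_partition / num_partitions, num_partitions], transposed, and
// flattened into gather indices t; e.g. {0, 1, 2, 3} -> [[0, 1], [2, 3]] ->
// [[0, 2], [1, 3]] -> t = {0, 2, 1, 3}. The one-hot matrix P[i][j] =
// (t[j] == i) then yields dot(input, P)[j] = input[t[j]].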
HloInstruction* ShuffleWithinEachPartitionUsingOneHot(HloInstruction* hlo,
                                                      int64_t num_partitions,
                                                      SpmdBuilder* b) {
  int64_t size_per_partition = hlo->shape().dimensions().back();
  CHECK_EQ(size_per_partition % num_partitions, 0);
  auto indices_iota = b->AddInstruction(HloInstruction::CreateIota(
      ShapeUtil::MakeShape(S32, {size_per_partition}), 0));
  auto reshape_indices_iota = b->AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(
          S32, {size_per_partition / num_partitions, num_partitions}),
      indices_iota));
  auto transpose_indices_iota =
      b->AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::MakeShape(
              S32, {num_partitions, size_per_partition / num_partitions}),
          reshape_indices_iota, {1, 0}));
  auto one_hot_indices = b->AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(S32, {size_per_partition, size_per_partition}),
      b->AddInstruction(HloInstruction::CreateReshape(
          ShapeUtil::MakeShape(S32, {size_per_partition}),
          transpose_indices_iota)),
      /*broadcast_dimensions=*/{1}));

  auto partition_indices = b->AddInstruction(HloInstruction::CreateIota(
      ShapeUtil::MakeShape(S32, {size_per_partition, size_per_partition}), 0));

  auto shuffle_one_hot = b->AddInstruction(HloInstruction::CreateConvert(
      ShapeUtil::ChangeElementType(partition_indices->shape(),
                                   hlo->shape().element_type()),
      b->AddInstruction(HloInstruction::CreateCompare(
          ShapeUtil::ChangeElementType(partition_indices->shape(), PRED),
          one_hot_indices, partition_indices, ComparisonDirection::kEq))));

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(hlo->shape().rank() - 1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  PrecisionConfig precision_config;
  precision_config.mutable_operand_precision()->Resize(
      2, PrecisionConfig::DEFAULT);
  HloInstruction* dot = b->AddInstruction(HloInstruction::CreateDot(
      hlo->shape(), hlo, shuffle_one_hot, dot_dnums, precision_config));
  return dot;
}

// If partition 0 has {0, 2, 1, 3}, partition 1 has {4, 0, 5, 0} and
// num partitions is 2, after all-to-all, partition 0 will have {0, 2, 4, 0}
// and partition 1 will have {1, 3, 5, 0}.
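// The all-to-all splits the last dimension into num_partitions equal blocks;
// block k of every partition is sent to partition k and the received blocks
// are concatenated in source-partition order. In the example, partition 0
// splits {0, 2, 1, 3} into {0, 2} and {1, 3}, keeps {0, 2}, sends {1, 3} to
// partition 1, and receives {4, 0} from it, giving {0, 2, 4, 0}.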
HloInstruction* ShuffleDataWithAllToAll(
    HloInstruction* hlo, int64_t num_partitions,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64_t* next_channel_id, SpmdBuilder* b) {
  std::vector<std::vector<int64_t>> groups(1);
  std::vector<int64_t> partition_subgroups(num_partitions);
  std::iota(partition_subgroups.begin(), partition_subgroups.end(), 0);
  groups[0] = partition_subgroups;
  auto all_to_all = collective_ops_creator.create_cross_partition_all_to_all(
      b, {hlo}, groups, (*next_channel_id)++, hlo->shape().rank() - 1);
  return all_to_all;
}

HloInstruction* GetCorrectionFactor(HloInstruction* hlo, int64_t num_partitions,
                                    HloInstruction* partition_id,
                                    SpmdBuilder* b) {
  /* n = size_per_replica
     m = num_partitions
     factor = tf.exp(-2.0j * np.pi * tf.cast(position_index, tf.complex64) *
                     tf.cast(tf.range(n), dtype=tf.complex64) /
                     (n * m))
  */
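  // Why this factor (a sketch of the standard decimation-in-time split): with
  // partition p holding the decimated sequence x_p[j] = x[j * m + p], the
  // length-(n * m) DFT decomposes as
  //   X[k + n * d] = sum_p exp(-2j*pi * p * d / m)
  //                        * (exp(-2j*pi * p * k / (n * m)) * F_p[k]),
  // where F_p is the length-n FFT of x_p, k in [0, n) and d in [0, m). This
  // function builds the inner twiddle term exp(-2j*pi * p * k / (n * m)); the
  // outer m-point sum over partitions is done by
  // GetFinalFftUsingCollectivePermute below.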
  auto add_hlo = [&](std::unique_ptr<HloInstruction> to_add) {
    return b->AddInstruction(std::move(to_add));
  };
  int64_t per_replica_size = hlo->shape().dimensions().back();
  auto constant_factor =
      add_hlo(HloInstruction::CreateConstant(LiteralUtil::CreateR0(
          complex64(0, -2.0 * M_PI / (num_partitions * per_replica_size)))));
  constant_factor = add_hlo(HloInstruction::CreateBroadcast(
      hlo->shape(), constant_factor, /*broadcast_dimensions=*/{}));
  auto converted_partition_id = add_hlo(HloInstruction::CreateConvert(
      ShapeUtil::ChangeElementType(partition_id->shape(),
                                   hlo->shape().element_type()),
      partition_id));
  // TODO(wangtao): multiply before broadcast.
  auto broadcast_partition_id = add_hlo(HloInstruction::CreateBroadcast(
      hlo->shape(), converted_partition_id, /*broadcast_dimensions=*/{}));
  auto exp_operand = add_hlo(
      HloInstruction::CreateBinary(hlo->shape(), HloOpcode::kMultiply,
                                   constant_factor, broadcast_partition_id));
  auto iota = add_hlo(
      HloInstruction::CreateIota(hlo->shape(), hlo->shape().rank() - 1));
  exp_operand = add_hlo(HloInstruction::CreateBinary(
      hlo->shape(), HloOpcode::kMultiply, exp_operand, iota));
  return add_hlo(
      HloInstruction::CreateUnary(hlo->shape(), HloOpcode::kExp, exp_operand));
}

// Pseudo code for the while loop:
// def body(dest_transform, dest_core_position, source_transform,
//             source_core_position, i):
//      factor = tf.exp(-2.0j * np.pi *
//                      tf.cast(dest_core_position, tf.complex64) *
//                tf.cast(source_core_position, tf.complex64) / num_partitions)
//      dest_transform += factor * source_transform
//      source_core_position = tf.raw_ops.CollectivePermute(
//          input=source_core_position,
//          source_target_pairs=source_target_pairs,
//          name='source_core_position_permute')
//      source_transform = tf.raw_ops.CollectivePermute(
//          input=source_transform,
//          source_target_pairs=source_target_pairs,
//          name='source_transform_permute')
//      i += 1
//      return (dest_transform, dest_core_position, source_transform,
//              source_core_position, i)
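// After num_partitions iterations, destination partition d has accumulated
//   sum_s exp(-2j*pi * d * s / num_partitions) * y_s,
// i.e. the num_partitions-point DFT across partitions of the twiddle-corrected
// local transforms y_s, evaluated at frequency index d.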
HloInstruction* GetFinalFftUsingCollectivePermute(
    HloInstruction* hlo, const HloSharding& sharding,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64_t num_partitions, HloInstruction* partition_id,
    int64_t* next_channel_id, HloModule* module, SpmdBuilder* b) {
  auto iteration = b->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32_t>(0)));
  auto converted_partition_id = b->AddInstruction(HloInstruction::CreateConvert(
      ShapeUtil::ChangeElementType(partition_id->shape(),
                                   hlo->shape().element_type()),
      partition_id));
  // Build while loop body.
  SpmdBuilder body_b("fft_collective_permute_body", hlo);
  auto param = body_b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0,
      ShapeUtil::MakeTupleShape(
          {hlo->shape(), hlo->shape(), converted_partition_id->shape(),
           converted_partition_id->shape(), iteration->shape()}),
      "param"));
  auto dest_transform = body_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(hlo->shape(), param, 0));
  auto source_transform = body_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(hlo->shape(), param, 1));
  auto dest_partition_id =
      body_b.AddInstruction(HloInstruction::CreateGetTupleElement(
          converted_partition_id->shape(), param, 2));
  auto source_partition_id =
      body_b.AddInstruction(HloInstruction::CreateGetTupleElement(
          converted_partition_id->shape(), param, 3));
  auto i = body_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(iteration->shape(), param, 4));
  /*
    factor = tf.exp(-2.0j * np.pi *
                    tf.cast(dest_partition_id, tf.complex64) *
                    tf.cast(source_partition_id, tf.complex64) /
                    num_partitions)
    dest_transform += factor * source_transform
  */
  auto constant_factor = body_b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR0(complex64(0, -2.0 * M_PI / num_partitions))));

  constant_factor = body_b.AddInstruction(HloInstruction::CreateBinary(
      constant_factor->shape(), HloOpcode::kMultiply, constant_factor,
      dest_partition_id));
  constant_factor = body_b.AddInstruction(HloInstruction::CreateBinary(
      constant_factor->shape(), HloOpcode::kMultiply, constant_factor,
      source_partition_id));
  auto phase_factor = body_b.AddInstruction(HloInstruction::CreateUnary(
      constant_factor->shape(), HloOpcode::kExp, constant_factor));
  phase_factor = body_b.AddInstruction(
      HloInstruction::CreateBroadcast(hlo->shape(), phase_factor, {}));
  auto phase_adjust_source_transform =
      body_b.AddInstruction(HloInstruction::CreateBinary(
          hlo->shape(), HloOpcode::kMultiply, phase_factor, source_transform));
  dest_transform = body_b.AddInstruction(HloInstruction::CreateBinary(
      hlo->shape(), HloOpcode::kAdd, phase_adjust_source_transform,
      dest_transform));
  // Collective permute for source_partition_id and source_transform.
  std::vector<std::pair<int64_t, int64_t>> src_dst_pairs;
  sharding.tile_assignment().Each(
      [&](absl::Span<const int64_t> indices, int64_t src_device) {
        std::vector<int64_t> target_indices(indices.begin(), indices.end());
        target_indices.back() = (indices.back() + 1) % num_partitions;
        int64_t dst_device = sharding.tile_assignment()(target_indices);
        src_dst_pairs.emplace_back(src_device, dst_device);
      });

  source_partition_id =
      collective_ops_creator.create_cross_partition_collective_permute(
          &body_b, source_partition_id, src_dst_pairs, (*next_channel_id)++);

  source_transform =
      collective_ops_creator.create_cross_partition_collective_permute(
          &body_b, source_transform, src_dst_pairs, (*next_channel_id)++);

  // ++i
  i = body_b.AddInstruction(HloInstruction::CreateBinary(
      i->shape(), HloOpcode::kAdd, i,
      body_b.AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32_t>(1)))));
  body_b.AddInstruction(
      HloInstruction::CreateTuple({dest_transform, source_transform,
                                   dest_partition_id, source_partition_id, i}));

  // Build while loop condition.
  auto zero = CreateZero(hlo->shape(), b);
  SpmdBuilder cond_b("fft_collective_permute_condition", hlo);
  auto cond_param = cond_b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0,
      ShapeUtil::MakeTupleShape(
          {hlo->shape(), hlo->shape(), converted_partition_id->shape(),
           converted_partition_id->shape(), iteration->shape()}),
      "param"));
  auto cond_i = cond_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(iteration->shape(), cond_param, 4));
  cond_b.AddInstruction(HloInstruction::CreateCompare(
      ShapeUtil::MakeShape(PRED, {}), cond_i,
      cond_b.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR0<uint32_t>(num_partitions))),
      ComparisonDirection::kLt));

  // Build while loop.
  auto while_loop = b->AddInstruction(HloInstruction::CreateWhile(
      cond_param->shape(), module->AddEmbeddedComputation(cond_b.Build()),
      module->AddEmbeddedComputation(body_b.Build()),
      b->AddInstruction(
          HloInstruction::CreateTuple({zero, hlo, converted_partition_id,
                                       converted_partition_id, iteration}))));

  return b->AddInstruction(
      HloInstruction::CreateGetTupleElement(hlo->shape(), while_loop, 0));
}

// Slice valid data in each partition.
HloInstruction* SliceValidData(HloInstruction* hlo, const Shape& target_shape,
                               SpmdBuilder* b) {
  std::vector<int64_t> start_indices(target_shape.rank(), 0);
  std::vector<int64_t> strides(target_shape.rank(), 1);
  return b->AddInstruction(HloInstruction::CreateSlice(
      target_shape, hlo, start_indices, target_shape.dimensions(), strides));
}

}  // namespace

// Distributed FFT using the algorithm described in go/tpu-spmd-fft.
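// Sketch of the three phases below (n = fft_length / num_partitions,
// m = num_partitions):
//   1. Redistribute the last dimension so partition p holds the decimated
//      sequence x[j * m + p] (halo-exchange pad, one-hot shuffle, all-to-all,
//      slice).
//   2. Run a local length-n FFT on each partition and multiply by the twiddle
//      factor exp(-2j*pi * p * k / (n * m)).
//   3. Combine partitions with an m-point DFT implemented as a
//      collective-permute while loop.
// An illustrative numpy reference for the same decomposition (not taken from
// go/tpu-spmd-fft):
//   y = [np.fft.fft(x[p::m]) * np.exp(-2j * np.pi * p * np.arange(n) / (n * m))
//        for p in range(m)]
//   X = np.concatenate(
//       [sum(np.exp(-2j * np.pi * d * s / m) * y[s] for s in range(m))
//        for d in range(m)])  # equals np.fft.fft(x)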
Status SpmdPartitioningVisitor::HandleFft(HloInstruction* hlo) {
  if (hlo->operand(0)->shape().rank() < 3 || hlo->fft_type() != FftType::FFT) {
    return DefaultAction(hlo);
  }

  // Only support the case where input_length equals fft_length.
  int64_t input_length = hlo->operand(0)->shape().dimensions().back();
  int64_t fft_length = hlo->fft_length().back();
  if (input_length != fft_length || input_length % num_partitions_ != 0) {
    return DefaultAction(hlo);
  }

  // Support partitioning along the last dimension only.
  if (!hlo->has_sharding() ||
      hlo->sharding().tile_assignment().dimensions().back() !=
          num_partitions_) {
    return DefaultAction(hlo);
  }

  auto partitioned_input =
      GetPartitionedHlo(hlo->operand(0))
          .PadWithValue(CreateR0WithType(hlo->shape().element_type(), 0, &b_));

  // 1.a. Use right halo exchange to shuffle data first and slice the valid
  // data. Data shuffling ensures an in-order transform, i.e. the sequence of
  // data is the same before and after the transform. The data shuffling
  // requires that the size of data per partition be divisible by the number
  // of partitions. For example, if the input is {0, 1, 2, 3, 4, 5} and
  // num partitions is 2, after the halo exchange partition 0 has {0, 1, 2, 3}
  // and partition 1 has {4, 5, 0, 0}, where the 0s in partition 1 are padding
  // data. Zero padding appends zeros to the end of the full data.
  auto result = partitioned_input.hlo();
  auto padded_hlo = PadEachPartitionWithHaloExchange(
      partitioned_input.hlo(), num_partitions_, hlo->sharding(),
      partitioned_input.state().collective_ops_creator,
      partitioned_input.state().next_channel_id,
      partitioned_input.state().partition_id, partitioned_input.state().b);

  if (padded_hlo.has_value()) {
    result = padded_hlo.value();
  }

  // 1.b Shuffle data within each partition using one hot and matmul.
  // If partition 0 has {0, 1, 2, 3} and num partitions is 2, after shuffling,
  // the data becomes {0, 2, 1, 3}.
  result = ShuffleWithinEachPartitionUsingOneHot(result, num_partitions_,
                                                 partitioned_input.state().b);
  // 1.c All-to-all.
  // If partition 0 has {0, 2, 1, 3}, partition 1 has {4, 0, 5, 0} and
  // num partitions is 2, after all-to-all, partition 0 will have {0, 2, 4, 0}
  // and partition 1 will have {1, 3, 5, 0}.
  result = ShuffleDataWithAllToAll(
      result, num_partitions_, partitioned_input.state().collective_ops_creator,
      partitioned_input.state().next_channel_id, partitioned_input.state().b);
  // 1.d Slice valid data in each partition.
  result = SliceValidData(result, partitioned_input.hlo()->shape(), &b_);

  // 2. Do local fft transform.
  auto partitioned_fft_length = hlo->fft_length();
  partitioned_fft_length.back() /= num_partitions_;
  result = b_.AddInstruction(HloInstruction::CreateFft(
      result->shape(), result, hlo->fft_type(), partitioned_fft_length));

  // Multiply by the correction factor for local phase adjustment.
  auto correction_factor = GetCorrectionFactor(
      result, num_partitions_, partitioned_input.state().partition_id,
      partitioned_input.state().b);
  result = b_.AddInstruction(HloInstruction::CreateBinary(
      result->shape(), HloOpcode::kMultiply, result, correction_factor));
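  // At this point partition p holds exp(-2j*pi * p * k / (n * m)) * F_p[k]:
  // the twiddle-corrected local transform that phase 3 combines across
  // partitions.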

  // 3. Second phase FFT with collective permute. fft_length = num_partitions.
  result = GetFinalFftUsingCollectivePermute(
      result, hlo->sharding(), partitioned_input.state().collective_ops_creator,
      num_partitions_, partitioned_input.state().partition_id,
      partitioned_input.state().next_channel_id, module_,
      partitioned_input.state().b);

  result->set_sharding(hlo->sharding());
  auto partitioned_fft =
      PartitionedHlo(result, hlo->shape(), partitioned_input.state());
  SetPartitionedHlo(hlo, partitioned_fft);
  return OkStatus();
}

}  // namespace spmd
}  // namespace xla