/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/spmd/custom_call_handler.h"

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_lexer.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"

namespace xla {
namespace spmd {

namespace {

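// Parses the custom call's opaque field as a comma-separated list of
// `name=integer` attributes (e.g. "dimension=0,amount=4") into a map.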
StatusOr<absl::flat_hash_map<std::string, int64_t>> ParseOpaqueAsAttributes(
    const HloInstruction* hlo) {
  absl::string_view opaque = Cast<HloCustomCallInstruction>(hlo)->opaque();
  HloLexer lexer(opaque);
  absl::flat_hash_map<std::string, int64_t> result;
  while (lexer.Lex() != TokKind::kEof) {
    if (lexer.GetKind() != TokKind::kAttributeName) {
      return InvalidArgument("Expects attribute name, %s", opaque);
    }
    std::string attr_name = lexer.GetStrVal();
    if (lexer.Lex() != TokKind::kInt) {
      return InvalidArgument("expects integer attribute value");
    }
    result[attr_name] = lexer.GetInt64Val();
    if (lexer.Lex() != TokKind::kComma) {
      break;
    }
  }
  return result;
}

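// Custom call target for the SPMD-internal rotate-right op; instructions with
// this target are created by CreateCustomCallSPMDInternal_RotateRight below
// and partitioned by HandleCustomCallSPMDInternal_RotateRight.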
constexpr char kSPMDOpRotateRight[] = "_SPMDInternalOp_RotateRight";

}  // namespace

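// Partitions the "TopK" custom call when the input is tiled along the sort
// dimension: each partition runs TopK on its own shard, the per-partition
// values and indices are resharded to be (partially) replicated, and a
// replicated Sort followed by a slice of the first k elements produces the
// final result.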
Status SpmdPartitioningVisitor::HandleCustomCallTopK(HloInstruction* hlo) {
  if (!hlo->operand(0)->has_sharding()) {
    return DefaultAction(hlo);
  }

  const HloSharding& sharding = hlo->operand(0)->sharding();
  // No support for partial replication yet.
  if (sharding.IsTileMaximal() || sharding.IsReplicated() ||
      sharding.ReplicateOnLastTileDim()) {
    return DefaultAction(hlo);
  }

  const int64_t batch_dim = 0;
  const int64_t sort_dim = 1;
  const int64_t shard_count = sharding.tile_assignment().dim(sort_dim);

  if (shard_count <= 1) {
    return DefaultAction(hlo);
  }

  const int64_t batch_dim_partition = sharding.tile_assignment().dim(batch_dim);
  const int64_t input_size = hlo->operand(0)->shape().dimensions(sort_dim);
  const int64_t batch_size = hlo->shape().tuple_shapes(0).dimensions(batch_dim);
  const int64_t k = hlo->shape().tuple_shapes(0).dimensions(sort_dim);
  const int64_t per_partition_size = CeilOfRatio(input_size, shard_count);

  if (k >= per_partition_size) {
    return DefaultAction(hlo);
  }

  auto input = hlo->operand(0);
  const auto element_type = input->shape().element_type();

  auto partitioned_input = GetPartitionedHlo(input).PadWithValue(
      CreateFirstWithType(element_type, &b_));

  auto partition_state = partitioned_input.state();
  auto replicated_sharding = HloSharding::Replicate();
  // If the batch dimension is partitioned, use partial replication on the
  // sort dimension.
  if (batch_dim_partition > 1) {
    auto sharding_grouped =
        hlo_sharding_util::GroupShardingOnDims(sharding, {batch_dim});
    partition_state = CreatePerGroupPartitioningState(
        partitioned_input.state(), sharding_grouped.device_groups,
        partitioned_input.state().b);
    auto reshape_tile_assignment = sharding.tile_assignment();
    auto reshape_dimensions = reshape_tile_assignment.dimensions();
    reshape_dimensions.push_back(reshape_dimensions.back());
    reshape_dimensions[sort_dim] = 1;
    reshape_tile_assignment.Reshape(reshape_dimensions);
    replicated_sharding = HloSharding::PartialTile(reshape_tile_assignment);
  }

  // Each partition needs to do TopK separately, thus the base shape
  // becomes [batch_size, k * shard_count].
  const Shape replicated_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(hlo->operand(0)->shape().element_type(),
                            {batch_size, k * shard_count}),
       ShapeUtil::MakeShape(S32, {batch_size, k * shard_count})});
  auto custom_call_sharding =
      sharding.GetTupleSharding(replicated_shape).ValueOrDie();
  auto shard_shape =
      MakePartitionedShape(replicated_shape, custom_call_sharding);
  auto topk = b_.AddInstruction(
      hlo->CloneWithNewOperands(shard_shape, {partitioned_input.hlo()}));
  topk->set_sharding(custom_call_sharding);
  // Partition the custom call.
  PartitionedHlo partitioned_topk(topk, replicated_shape,
                                  MakePartitioningState());
  topk = partitioned_topk.hlo();

  // Get value from TopK.
  HloInstruction* value_gte =
      b_.AddInstruction(HloInstruction::CreateGetTupleElement(
          topk->shape().tuple_shapes(0), topk, 0));
  value_gte->set_sharding(sharding);
  // Partition GetTupleElement of value.
  PartitionedHlo value_partitioned_gte(
      value_gte, partitioned_topk.base_shape().tuple_shapes(0),
      MakePartitioningState());
  // Reshard value to be replicated.
  auto replicated_value_gte =
      value_partitioned_gte.Reshard(replicated_sharding).hlo();

  // Get index from TopK.
  HloInstruction* index_gte =
      b_.AddInstruction(HloInstruction::CreateGetTupleElement(
          topk->shape().tuple_shapes(1), topk, 1));
  auto partition_id_s32 = b_.AddInstruction(HloInstruction::CreateConvert(
      ShapeUtil::MakeShape(S32, partition_id_->shape().dimensions()),
      partition_state.partition_id));
  // Add the per-partition offset to the index; the index returned from the
  // custom call always starts from 0.
  auto index_offset = b_.AddInstruction(HloInstruction::CreateBroadcast(
      index_gte->shape(),
      b_.AddInstruction(HloInstruction::CreateBinary(
          partition_id_s32->shape(), HloOpcode::kMultiply, partition_id_s32,
          b_.AddInstruction(HloInstruction::CreateConstant(
              LiteralUtil::CreateR0<int32_t>(per_partition_size))))),
      {}));
  index_gte = b_.AddInstruction(HloInstruction::CreateBinary(
      index_offset->shape(), HloOpcode::kAdd, index_gte, index_offset));
  index_gte->set_sharding(sharding);
  // Partition GetTupleElement of index.
  PartitionedHlo index_partitioned_gte(
      index_gte, partitioned_topk.base_shape().tuple_shapes(1),
      MakePartitioningState());
  // Reshard index to be replicated.
  auto replicated_index_gte =
      index_partitioned_gte.Reshard(replicated_sharding).hlo();

  // Create a replicated sort to do TopK; the input is the value and index
  // pairs from all the partitions. Sort is used instead of the TopK custom
  // call because the custom call only takes the values as input, so an extra
  // Gather would be needed here to recover the correct indices.

  // Create comparator for the sort.
  XlaBuilder b("Sort.Compare");
  XlaComputation comparator = CreateScalarComparisonComputation(
      "compare-value-and-index", {input->shape().element_type(), S32}, {Gt, Lt},
      &b);
  TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comparator.GetProgramShape());
  HloModuleConfig config(program_shape);
  TF_ASSIGN_OR_RETURN(auto new_module,
                      HloModule::CreateFromProto(comparator.proto(), config));
  HloCloneContext context(module_);
  auto compare_computation =
      module_->DeepCloneComputation(new_module->entry_computation(), &context);
  // Each partition needs to do TopK separately, thus the base shape for sort
  // becomes [ceil(batch_size / batch_dim_partition), k * shard_count].
  const Shape sort_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(
           hlo->operand(0)->shape().element_type(),
           {CeilOfRatio(batch_size, batch_dim_partition), k * shard_count}),
       ShapeUtil::MakeShape(S32, {CeilOfRatio(batch_size, batch_dim_partition),
                                  k * shard_count})});
  auto sort = b_.AddInstruction(HloInstruction::CreateSort(
      sort_shape, sort_dim, {replicated_value_gte, replicated_index_gte},
      compare_computation, true));
  sort->set_sharding(
      replicated_sharding.GetTupleSharding(sort->shape()).ValueOrDie());
  PartitionedHlo replicated_sort(sort, replicated_shape,
                                 MakePartitioningState());

  // Slice value and index from top-k for output.
  HloInstruction* sort_value_gte =
      b_.AddInstruction(HloInstruction::CreateGetTupleElement(
          replicated_sort.hlo()->shape().tuple_shapes(0), replicated_sort.hlo(),
          0));
  HloInstruction* sort_index_gte =
      b_.AddInstruction(HloInstruction::CreateGetTupleElement(
          replicated_sort.hlo()->shape().tuple_shapes(1), replicated_sort.hlo(),
          1));
  // Slice value from final sort.
  HloInstruction* slice_sort_value =
      SliceFirstK(sort_value_gte, &b_, sort_dim, k);
  // Slice index from final sort.
  HloInstruction* slice_index_value =
      SliceFirstK(sort_index_gte, &b_, sort_dim, k);
  auto create_tuple = b_.AddInstruction(
      HloInstruction::CreateTuple({slice_sort_value, slice_index_value}));
  create_tuple->set_sharding(
      replicated_sharding.GetTupleSharding(create_tuple->shape()).ValueOrDie());
  SetPartitionedHlo(
      hlo, PartitionedHlo(create_tuple, hlo->shape(), MakePartitioningState())
               .Reshard(hlo->sharding()));

  return OkStatus();
}

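// Partitions the SPMD-internal rotate-right custom call. The rotate is
// performed on the padded, sharded data using slices and collective permutes;
// when the full dimension size is not a multiple of the shard size, a second
// rotate compensates for the right padding and the two results are combined
// with a select.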
Status SpmdPartitioningVisitor::HandleCustomCallSPMDInternal_RotateRight(
    HloInstruction* hlo) {
  TF_ASSIGN_OR_RETURN(auto attrs, ParseOpaqueAsAttributes(hlo));
  auto dim_it = attrs.find("dimension");
  TF_RET_CHECK(dim_it != attrs.end())
      << "No dimension attribute in SPMD rotate op";
  int64_t dim = dim_it->second;
  auto amount_it = attrs.find("amount");
  TF_RET_CHECK(amount_it != attrs.end())
      << "No amount attribute in SPMD rotate op";

  PartitionedHlo input =
      GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding());
  const int64_t full_size = hlo->shape().dimensions(dim);
  const int64_t shard_size = input.hlo()->shape().dimensions(dim);

  // We exclude shards that are entirely padding.
  const int64_t participating_shards = CeilOfRatio(full_size, shard_size);
  // The last included shard might still have padding on the right.
  const int64_t right_padding = participating_shards * shard_size - full_size;
  int64_t amount = amount_it->second;
  TF_RET_CHECK(amount >= 0)
      << "Rotate amount cannot be negative in SPMD rotate op";

  amount %= full_size;
  if (amount == 0) {
    SetPartitionedHlo(hlo, input);
    return OkStatus();
  }

  // First step: rotate `amount` on padded data. E.g., before
  //      012|345|678|9__     (_: padding)
  // after:
  //      678|9__|012|345     (amount: 6)
  auto rotate_with_padding = [&](int64_t rotate_amount) {
    int64_t current_size = 0;
    std::vector<HloInstruction*> concat_pieces;
    while (current_size < shard_size) {
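      // The piece that starts at local offset `current_size` in the rotated
      // shard comes from the shard `shard_distance` positions to its left
      // (wrapping over the participating shards), starting at
      // `offset_in_shard` within that source shard.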
      int64_t shard_distance =
          CeilOfRatio(rotate_amount - current_size, shard_size);
      int64_t offset_in_shard =
          shard_distance * shard_size - rotate_amount + current_size;

      int64_t halo_size =
          std::min(shard_size - offset_in_shard, shard_size - current_size);

      current_size += halo_size;
      Shape halo_shape = input.hlo()->shape();
      halo_shape.set_dimensions(dim, halo_size);
      HloInstruction* halo = input.hlo();
      if (halo_size != shard_size) {
        std::vector<int64_t> slice_starts(hlo->shape().rank(), 0);
        slice_starts[dim] = offset_in_shard;
        std::vector<int64_t> slice_limits(
            input.hlo()->shape().dimensions().begin(),
            input.hlo()->shape().dimensions().end());
        slice_limits[dim] = offset_in_shard + halo_size;
        halo = b_.AddInstruction(HloInstruction::CreateSlice(
            halo_shape, halo, slice_starts, slice_limits,
            std::vector<int64_t>(halo_shape.rank(), 1)));
      }
      if (shard_distance != 0) {
        std::vector<std::pair<int64_t, int64_t>> pairs;
        hlo->sharding().tile_assignment().Each(
            [&](absl::Span<const int64_t> indices, int64_t device) {
              if (indices[dim] >= participating_shards) {
                return;
              }
              std::vector<int64_t> dst_idx(indices.begin(), indices.end());
              dst_idx[dim] += shard_distance;
              dst_idx[dim] %= participating_shards;
              pairs.emplace_back(device,
                                 hlo->sharding().tile_assignment()(dst_idx));
            });
        halo =
            collective_ops_creator_.create_cross_partition_collective_permute(
                &b_, halo, pairs, NewChannel());
      }
      concat_pieces.push_back(halo);
    }
    if (concat_pieces.size() > 1) {
      return b_.AddInstruction(HloInstruction::CreateConcatenate(
          input.hlo()->shape(), concat_pieces, dim));
    }
    return concat_pieces[0];
  };
  HloInstruction* rotated0 = rotate_with_padding(amount);
  if (right_padding == 0) {
    SetPartitionedHlo(hlo, [&] { return rotated0; });
    return OkStatus();
  }

  // Second step: perform another rotate from the input, with `right_padding`
  // added to `amount`. E.g., before
  //      012|345|678|9__     (_: padding)
  // after:
  //      456|789|__0|123     (amount: 6 + 2)
  // Combine (select) it with the first step:
  //      678|9__|012|345
  // to get:
  //      456|789|012|3__

  HloInstruction* rotated1 = rotate_with_padding(
      (amount + right_padding) % (shard_size * participating_shards));
  HloInstruction* shard_offset = MakePartitionOffsets(
      hlo->shape(), hlo->sharding(), MakePartitioningState().partition_id, &b_,
      {dim})[dim];
  HloInstruction* iota = b_.AddInstruction(HloInstruction::CreateIota(
      ShapeUtil::ChangeElementType(rotated0->shape(), S32), dim));
  HloInstruction* selection_boundary =
      b_.AddInstruction(HloInstruction::CreateBroadcast(
          iota->shape(),
          b_.AddInstruction(HloInstruction::CreateBinary(
              shard_offset->shape(), HloOpcode::kSubtract,
              b_.AddInstruction(HloInstruction::CreateConstant(
                  LiteralUtil::CreateR0<int32_t>(amount))),
              shard_offset)),
          {}));
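  // Take rotated1 (the padding-compensated rotate) for elements whose global
  // index along `dim` is less than `amount`, since that prefix of rotated0
  // came from the padded tail of the input; take rotated0 elsewhere.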
  HloInstruction* pred = b_.AddInstruction(HloInstruction::CreateCompare(
      ShapeUtil::ChangeElementType(iota->shape(), PRED), iota,
      selection_boundary, Comparison::Direction::kLt));
  SetPartitionedHlo(hlo, [&] {
    return b_.AddInstruction(HloInstruction::CreateTernary(
        rotated0->shape(), HloOpcode::kSelect, pred, rotated1, rotated0));
  });
  return OkStatus();
}

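// Creates the SPMD-internal rotate-right custom call, encoding the dimension
// and amount in the opaque field so they can be parsed back by
// ParseOpaqueAsAttributes.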
std::unique_ptr<HloInstruction> CreateCustomCallSPMDInternal_RotateRight(
    HloInstruction* input, int64_t dim, int64_t amount) {
  std::string opaque = absl::StrCat("dimension=", dim, ",amount=", amount);
  return HloInstruction::CreateCustomCall(input->shape(), {input},
                                          kSPMDOpRotateRight, opaque);
}

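// Dispatches custom calls: the SPMDFullToShardShape/SPMDShardToFullShape
// manual-partitioning boundaries, TopK, the internal rotate-right op, and the
// single-device and fully manual cases; anything else falls back to
// DefaultAction.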
Status SpmdPartitioningVisitor::HandleCustomCall(HloInstruction* hlo) {
  if (hlo->custom_call_target() == "SPMDFullToShardShape") {
    // This op switches from auto partitioning to manual partitioning.
    auto input_partitioned = GetPartitionedHlo(hlo->operand(0));
    if (!EvenlyPartitions(hlo->shape(), input_partitioned.sharding())) {
      input_partitioned = input_partitioned.PadWithValue(
          CreateR0WithType(hlo->shape().element_type(), 0, &b_));
    }
    auto input = input_partitioned.hlo();
    CHECK(hlo->sharding().IsManual() || hlo->sharding().IsManualSubgroup());
    CHECK(ShapeUtil::Compatible(
        input->shape(), MakePartitionedShape(hlo->shape(), hlo->sharding())));
    auto copy = b_.AddInstruction(
        HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input));
    SetPartitionedHlo(hlo, [&] { return copy; });
    return OkStatus();
  }
  if (hlo->custom_call_target() == "SPMDShardToFullShape") {
    // This op switches from manual partitioning to auto partitioning.
    auto input = GetPartitionedHlo(hlo->operand(0)).hlo();
    CHECK(input->sharding().IsManual() || input->sharding().IsManualSubgroup());
    auto copy = b_.AddInstruction(
        HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input));
    CHECK(ShapeUtil::Compatible(
        copy->shape(), MakePartitionedShape(hlo->shape(), hlo->sharding())));
    SetPartitionedHlo(hlo, [&] { return copy; });
    return OkStatus();
  }

  if (hlo->custom_call_target() == "TopK") {
    return HandleCustomCallTopK(hlo);
  }

  if (hlo->custom_call_target() == kSPMDOpRotateRight) {
    return HandleCustomCallSPMDInternal_RotateRight(hlo);
  }

  if (hlo->sharding().HasUniqueDevice()) {
    return HandleSingleDevice(hlo);
  }

  if (hlo->sharding().IsManual()) {
    // Handle manual custom calls by just cloning them and applying the
    // sharding the system expects, which is AssignDevice(0).
    std::vector<HloInstruction*> new_operands;
    new_operands.reserve(hlo->operands().size());
    for (HloInstruction* operand : hlo->operands()) {
      new_operands.push_back(GetPartitionedHlo(operand).hlo());
    }
    SetPartitionedHlo(hlo, [&] {
      auto* instr = b_.AddInstruction(
          hlo->CloneWithNewOperands(hlo->shape(), new_operands));
      if (hlo->shape().IsTuple()) {
        std::vector<HloSharding> subshardings(
            hlo->sharding().tuple_elements().size(),
            HloSharding::AssignDevice(0));
        instr->set_sharding(HloSharding::Tuple(hlo->shape(), subshardings));
      } else {
        instr->set_sharding(HloSharding::AssignDevice(0));
      }
      return instr;
    });
    return OkStatus();
  }

  return DefaultAction(hlo);
}

}  // namespace spmd
}  // namespace xla