1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/spmd/custom_call_handler.h"
17
18 #include <vector>
19
20 #include "absl/algorithm/container.h"
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/strings/str_cat.h"
23 #include "tensorflow/compiler/xla/client/lib/comparators.h"
24 #include "tensorflow/compiler/xla/client/xla_builder.h"
25 #include "tensorflow/compiler/xla/literal_util.h"
26 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
27 #include "tensorflow/compiler/xla/service/hlo_computation.h"
28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
29 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
30 #include "tensorflow/compiler/xla/service/hlo_lexer.h"
31 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
32 #include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
33 #include "tensorflow/compiler/xla/service/shape_inference.h"
34 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
35 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
36 #include "tensorflow/compiler/xla/shape_util.h"
37 #include "tensorflow/compiler/xla/util.h"
38 #include "tensorflow/compiler/xla/window_util.h"
39
40 namespace xla {
41 namespace spmd {
42
43 namespace {
44
ParseOpaqueAsAttributes(const HloInstruction * hlo)45 StatusOr<absl::flat_hash_map<std::string, int64_t>> ParseOpaqueAsAttributes(
46 const HloInstruction* hlo) {
47 absl::string_view opaque = Cast<HloCustomCallInstruction>(hlo)->opaque();
48 HloLexer lexer(opaque);
49 absl::flat_hash_map<std::string, int64_t> result;
50 while (lexer.Lex() != TokKind::kEof) {
51 if (lexer.GetKind() != TokKind::kAttributeName) {
52 return InvalidArgument("Expects attribute name, %s", opaque);
53 }
54 std::string attr_name = lexer.GetStrVal();
55 if (lexer.Lex() != TokKind::kInt) {
56 return InvalidArgument("expects integer attribute value");
57 }
58 result[attr_name] = lexer.GetInt64Val();
59 if (lexer.Lex() != TokKind::kComma) {
60 break;
61 }
62 }
63 return result;
64 }
65
// Custom-call target name identifying the partitioner-internal rotate-right
// op; HandleCustomCall dispatches on it and
// CreateCustomCallSPMDInternal_RotateRight emits it.
constexpr char kSPMDOpRotateRight[]{"_SPMDInternalOp_RotateRight"};
67
68 } // namespace
69
HandleCustomCallTopK(HloInstruction * hlo)70 Status SpmdPartitioningVisitor::HandleCustomCallTopK(HloInstruction* hlo) {
71 if (!hlo->operand(0)->has_sharding()) {
72 return DefaultAction(hlo);
73 }
74
75 const HloSharding& sharding = hlo->operand(0)->sharding();
76 // No support for partial replicate yet.
77 if (sharding.IsTileMaximal() || sharding.IsReplicated() ||
78 sharding.ReplicateOnLastTileDim()) {
79 return DefaultAction(hlo);
80 }
81
82 const int64_t batch_dim = 0;
83 const int64_t sort_dim = 1;
84 const int64_t shard_count = sharding.tile_assignment().dim(sort_dim);
85
86 if (shard_count <= 1) {
87 return DefaultAction(hlo);
88 }
89
90 const int64_t batch_dim_partition = sharding.tile_assignment().dim(batch_dim);
91 const int64_t input_size = hlo->operand(0)->shape().dimensions(sort_dim);
92 const int64_t batch_size = hlo->shape().tuple_shapes(0).dimensions(batch_dim);
93 const int64_t k = hlo->shape().tuple_shapes(0).dimensions(sort_dim);
94 const int64_t per_partition_size = CeilOfRatio(input_size, shard_count);
95
96 if (k >= per_partition_size) {
97 return DefaultAction(hlo);
98 }
99
100 auto input = hlo->operand(0);
101 const auto element_type = input->shape().element_type();
102
103 auto partitioned_input = GetPartitionedHlo(input).PadWithValue(
104 CreateFirstWithType(element_type, &b_));
105
106 auto partition_state = partitioned_input.state();
107 auto replicated_sharding = HloSharding::Replicate();
108 // If batch dimension is partitioned, partial replicated on sort dimension.
109 if (batch_dim_partition > 1) {
110 auto sharding_grouped =
111 hlo_sharding_util::GroupShardingOnDims(sharding, {batch_dim});
112 partition_state = CreatePerGroupPartitioningState(
113 partitioned_input.state(), sharding_grouped.device_groups,
114 partitioned_input.state().b);
115 auto reshape_tile_assignment = sharding.tile_assignment();
116 auto reshape_dimensions = reshape_tile_assignment.dimensions();
117 reshape_dimensions.push_back(reshape_dimensions.back());
118 reshape_dimensions[sort_dim] = 1;
119 reshape_tile_assignment.Reshape(reshape_dimensions);
120 replicated_sharding = HloSharding::PartialTile(reshape_tile_assignment);
121 }
122
123 // Each partition needs to do TopK separately, thus the base shape
124 // becomes [batch_size, k * shard_count].
125 const Shape replicated_shape = ShapeUtil::MakeTupleShape(
126 {ShapeUtil::MakeShape(hlo->operand(0)->shape().element_type(),
127 {batch_size, k * shard_count}),
128 ShapeUtil::MakeShape(S32, {batch_size, k * shard_count})});
129 auto custom_call_sharding =
130 sharding.GetTupleSharding(replicated_shape).ValueOrDie();
131 auto shard_shape =
132 MakePartitionedShape(replicated_shape, custom_call_sharding);
133 auto topk = b_.AddInstruction(
134 hlo->CloneWithNewOperands(shard_shape, {partitioned_input.hlo()}));
135 topk->set_sharding(custom_call_sharding);
136 // Partition customcall.
137 PartitionedHlo partitioned_topk(topk, replicated_shape,
138 MakePartitioningState());
139 topk = partitioned_topk.hlo();
140
141 // Get value from TopK.
142 HloInstruction* value_gte =
143 b_.AddInstruction(HloInstruction::CreateGetTupleElement(
144 topk->shape().tuple_shapes(0), topk, 0));
145 value_gte->set_sharding(sharding);
146 // Partition GetTupleElement of value.
147 PartitionedHlo value_partitioned_gte(
148 value_gte, partitioned_topk.base_shape().tuple_shapes(0),
149 MakePartitioningState());
150 // Reshard value to be replicated.
151 auto replicated_value_gte =
152 value_partitioned_gte.Reshard(replicated_sharding).hlo();
153
154 // Get index from TopK.
155 HloInstruction* index_gte =
156 b_.AddInstruction(HloInstruction::CreateGetTupleElement(
157 topk->shape().tuple_shapes(1), topk, 1));
158 auto partition_id_s32 = b_.AddInstruction(HloInstruction::CreateConvert(
159 ShapeUtil::MakeShape(S32, partition_id_->shape().dimensions()),
160 partition_state.partition_id));
161 // Add per partition offset to index, index returned from CustomCall always
162 // starts from 0.
163 auto index_offset = b_.AddInstruction(HloInstruction::CreateBroadcast(
164 index_gte->shape(),
165 b_.AddInstruction(HloInstruction::CreateBinary(
166 partition_id_s32->shape(), HloOpcode::kMultiply, partition_id_s32,
167 b_.AddInstruction(HloInstruction::CreateConstant(
168 LiteralUtil::CreateR0<int32_t>(per_partition_size))))),
169 {}));
170 index_gte = b_.AddInstruction(HloInstruction::CreateBinary(
171 index_offset->shape(), HloOpcode::kAdd, index_gte, index_offset));
172 index_gte->set_sharding(sharding);
173 // Parttion GetTupleElement of index.
174 PartitionedHlo index_partitioned_gte(
175 index_gte, partitioned_topk.base_shape().tuple_shapes(1),
176 MakePartitioningState());
177 // Reshard index to be replicated.
178 auto replicated_index_gte =
179 index_partitioned_gte.Reshard(replicated_sharding).hlo();
180
181 // Creates replicated sort to do TopK, the input is value and index pairs
182 // from all the partitions. The reason to use Sort instead of CustomCall TopK
183 // is CustomCall only takes value as input. There will be an extra Gather
184 // to get the correct index if CustomCall is used here.
185
186 // Create comparator for the sort.
187 XlaBuilder b("Sort.Compare");
188 XlaComputation comparator = CreateScalarComparisonComputation(
189 "compare-value-and-index", {input->shape().element_type(), S32}, {Gt, Lt},
190 &b);
191 TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comparator.GetProgramShape());
192 HloModuleConfig config(program_shape);
193 TF_ASSIGN_OR_RETURN(auto new_module,
194 HloModule::CreateFromProto(comparator.proto(), config));
195 HloCloneContext context(module_);
196 auto compare_computation =
197 module_->DeepCloneComputation(new_module->entry_computation(), &context);
198 // Each partition needs to do TopK separately, thus the base shape for sort
199 // becomes [ceil(batch_size / batch_dim_partition), k * shard_count].
200 const Shape sort_shape = ShapeUtil::MakeTupleShape(
201 {ShapeUtil::MakeShape(
202 hlo->operand(0)->shape().element_type(),
203 {CeilOfRatio(batch_size, batch_dim_partition), k * shard_count}),
204 ShapeUtil::MakeShape(S32, {CeilOfRatio(batch_size, batch_dim_partition),
205 k * shard_count})});
206 auto sort = b_.AddInstruction(HloInstruction::CreateSort(
207 sort_shape, sort_dim, {replicated_value_gte, replicated_index_gte},
208 compare_computation, true));
209 sort->set_sharding(
210 replicated_sharding.GetTupleSharding(sort->shape()).ValueOrDie());
211 PartitionedHlo replicated_sort(sort, replicated_shape,
212 MakePartitioningState());
213
214 // Slice value and index from top-k for output.
215 HloInstruction* sort_value_gte =
216 b_.AddInstruction(HloInstruction::CreateGetTupleElement(
217 replicated_sort.hlo()->shape().tuple_shapes(0), replicated_sort.hlo(),
218 0));
219 HloInstruction* sort_index_gte =
220 b_.AddInstruction(HloInstruction::CreateGetTupleElement(
221 replicated_sort.hlo()->shape().tuple_shapes(1), replicated_sort.hlo(),
222 1));
223 // Slice value from final sort.
224 HloInstruction* slice_sort_value =
225 SliceFirstK(sort_value_gte, &b_, sort_dim, k);
226 // Slice index from final sort.
227 HloInstruction* slice_index_value =
228 SliceFirstK(sort_index_gte, &b_, sort_dim, k);
229 auto create_tuple = b_.AddInstruction(
230 HloInstruction::CreateTuple({slice_sort_value, slice_index_value}));
231 create_tuple->set_sharding(
232 replicated_sharding.GetTupleSharding(create_tuple->shape()).ValueOrDie());
233 SetPartitionedHlo(
234 hlo, PartitionedHlo(create_tuple, hlo->shape(), MakePartitioningState())
235 .Reshard(hlo->sharding()));
236
237 return OkStatus();
238 }
239
HandleCustomCallSPMDInternal_RotateRight(HloInstruction * hlo)240 Status SpmdPartitioningVisitor::HandleCustomCallSPMDInternal_RotateRight(
241 HloInstruction* hlo) {
242 TF_ASSIGN_OR_RETURN(auto attrs, ParseOpaqueAsAttributes(hlo));
243 auto dim_it = attrs.find("dimension");
244 TF_RET_CHECK(dim_it != attrs.end())
245 << "No dimension attribute in SPMD rotate op";
246 int64_t dim = dim_it->second;
247 auto amount_it = attrs.find("amount");
248 TF_RET_CHECK(amount_it != attrs.end())
249 << "No amount attribute in SPMD rotate op";
250
251 PartitionedHlo input =
252 GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding());
253 const int64_t full_size = hlo->shape().dimensions(dim);
254 const int64_t shard_size = input.hlo()->shape().dimensions(dim);
255
256 // We exclude shards that are entirely padding.
257 const int64_t participating_shards = CeilOfRatio(full_size, shard_size);
258 // The last included shard might still have padding on the right.
259 const int64_t right_padding = participating_shards * shard_size - full_size;
260 int64_t amount = amount_it->second;
261 TF_RET_CHECK(amount >= 0)
262 << "Rotate amount cannot be negative in SPMD rotate op";
263
264 amount %= full_size;
265 if (amount == 0) {
266 SetPartitionedHlo(hlo, input);
267 return OkStatus();
268 }
269
270 // First step: rotate `amount` on padded data. E.g., before
271 // 012|345|678|9__ (_: padding)
272 // after:
273 // 678|9__|012|345 (amount: 6)
274 auto rotate_with_padding = [&](int64_t rotate_amount) {
275 int64_t current_size = 0;
276 std::vector<HloInstruction*> concat_pieces;
277 while (current_size < shard_size) {
278 int64_t shard_distance =
279 CeilOfRatio(rotate_amount - current_size, shard_size);
280 int64_t offset_in_shard =
281 shard_distance * shard_size - rotate_amount + current_size;
282
283 int64_t halo_size =
284 std::min(shard_size - offset_in_shard, shard_size - current_size);
285
286 current_size += halo_size;
287 Shape halo_shape = input.hlo()->shape();
288 halo_shape.set_dimensions(dim, halo_size);
289 HloInstruction* halo = input.hlo();
290 if (halo_size != shard_size) {
291 halo_shape.set_dimensions(dim, halo_size);
292 std::vector<int64_t> slice_starts(hlo->shape().rank(), 0);
293 slice_starts[dim] = offset_in_shard;
294 std::vector<int64_t> slice_limits(
295 input.hlo()->shape().dimensions().begin(),
296 input.hlo()->shape().dimensions().end());
297 slice_limits[dim] = offset_in_shard + halo_size;
298 halo = b_.AddInstruction(HloInstruction::CreateSlice(
299 halo_shape, halo, slice_starts, slice_limits,
300 std::vector<int64_t>(halo_shape.rank(), 1)));
301 }
302 if (shard_distance != 0) {
303 std::vector<std::pair<int64_t, int64_t>> pairs;
304 hlo->sharding().tile_assignment().Each(
305 [&](absl::Span<const int64_t> indices, int64_t device) {
306 if (indices[dim] >= participating_shards) {
307 return;
308 }
309 std::vector<int64_t> dst_idx(indices.begin(), indices.end());
310 dst_idx[dim] += shard_distance;
311 dst_idx[dim] %= participating_shards;
312 pairs.emplace_back(device,
313 hlo->sharding().tile_assignment()(dst_idx));
314 });
315 halo =
316 collective_ops_creator_.create_cross_partition_collective_permute(
317 &b_, halo, pairs, NewChannel());
318 }
319 concat_pieces.push_back(halo);
320 }
321 if (concat_pieces.size() > 1) {
322 return b_.AddInstruction(HloInstruction::CreateConcatenate(
323 input.hlo()->shape(), concat_pieces, dim));
324 }
325 return concat_pieces[0];
326 };
327 HloInstruction* rotated0 = rotate_with_padding(amount);
328 if (right_padding == 0) {
329 SetPartitionedHlo(hlo, [&] { return rotated0; });
330 return OkStatus();
331 }
332
333 // Second step: perform another rotate from input, with `right_padding` added
334 // to `amount`. E.g., before
335 // 012|345|678|9__ (_: padding)
336 // after:
337 // 456|789|__0|123 (amount: 6 + 2)
338 // combine (select) with first step:
339 // 678|9__|012|345
340 // now we get:
341 // 456|789|012|3__
342
343 HloInstruction* rotated1 = rotate_with_padding(
344 (amount + right_padding) % (shard_size * participating_shards));
345 HloInstruction* shard_offset = MakePartitionOffsets(
346 hlo->shape(), hlo->sharding(), MakePartitioningState().partition_id, &b_,
347 {dim})[dim];
348 HloInstruction* iota = b_.AddInstruction(HloInstruction::CreateIota(
349 ShapeUtil::ChangeElementType(rotated0->shape(), S32), dim));
350 HloInstruction* selection_boundary =
351 b_.AddInstruction(HloInstruction::CreateBroadcast(
352 iota->shape(),
353 b_.AddInstruction(HloInstruction::CreateBinary(
354 shard_offset->shape(), HloOpcode::kSubtract,
355 b_.AddInstruction(HloInstruction::CreateConstant(
356 LiteralUtil::CreateR0<int32_t>(amount))),
357 shard_offset)),
358 {}));
359 HloInstruction* pred = b_.AddInstruction(HloInstruction::CreateCompare(
360 ShapeUtil::ChangeElementType(iota->shape(), PRED), iota,
361 selection_boundary, Comparison::Direction::kLt));
362 SetPartitionedHlo(hlo, [&] {
363 return b_.AddInstruction(HloInstruction::CreateTernary(
364 rotated0->shape(), HloOpcode::kSelect, pred, rotated1, rotated0));
365 });
366 return OkStatus();
367 }
368
CreateCustomCallSPMDInternal_RotateRight(HloInstruction * input,int64_t dim,int64_t amount)369 std::unique_ptr<HloInstruction> CreateCustomCallSPMDInternal_RotateRight(
370 HloInstruction* input, int64_t dim, int64_t amount) {
371 std::string opaque = absl::StrCat("dimension=", dim, ",amount=", amount);
372 return HloInstruction::CreateCustomCall(input->shape(), {input},
373 kSPMDOpRotateRight, opaque);
374 }
375
HandleCustomCall(HloInstruction * hlo)376 Status SpmdPartitioningVisitor::HandleCustomCall(HloInstruction* hlo) {
377 if (hlo->custom_call_target() == "SPMDFullToShardShape") {
378 // This op switches from auto partitioning to manual partitioning.
379 auto input_partitioned = GetPartitionedHlo(hlo->operand(0));
380 if (!EvenlyPartitions(hlo->shape(), input_partitioned.sharding())) {
381 input_partitioned = input_partitioned.PadWithValue(
382 CreateR0WithType(hlo->shape().element_type(), 0, &b_));
383 }
384 auto input = input_partitioned.hlo();
385 CHECK(hlo->sharding().IsManual() || hlo->sharding().IsManualSubgroup());
386 CHECK(ShapeUtil::Compatible(
387 input->shape(), MakePartitionedShape(hlo->shape(), hlo->sharding())));
388 auto copy = b_.AddInstruction(
389 HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input));
390 SetPartitionedHlo(hlo, [&] { return copy; });
391 return OkStatus();
392 }
393 if (hlo->custom_call_target() == "SPMDShardToFullShape") {
394 // This op switches from manual partitioning to auto partitioning.
395 auto input = GetPartitionedHlo(hlo->operand(0)).hlo();
396 CHECK(input->sharding().IsManual() || input->sharding().IsManualSubgroup());
397 auto copy = b_.AddInstruction(
398 HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input));
399 CHECK(ShapeUtil::Compatible(
400 copy->shape(), MakePartitionedShape(hlo->shape(), hlo->sharding())));
401 SetPartitionedHlo(hlo, [&] { return copy; });
402 return OkStatus();
403 }
404
405 if (hlo->custom_call_target() == "TopK") {
406 return HandleCustomCallTopK(hlo);
407 }
408
409 if (hlo->custom_call_target() == kSPMDOpRotateRight) {
410 return HandleCustomCallSPMDInternal_RotateRight(hlo);
411 }
412
413 if (hlo->sharding().HasUniqueDevice()) {
414 return HandleSingleDevice(hlo);
415 }
416
417 if (hlo->sharding().IsManual()) {
418 // Handle manual custom calls by just cloning it and apply as sharding what
419 // the system expects, which is UniqueDevice(0).
420 std::vector<HloInstruction*> new_operands;
421 new_operands.reserve(hlo->operands().size());
422 for (HloInstruction* operand : hlo->operands()) {
423 new_operands.push_back(GetPartitionedHlo(operand).hlo());
424 }
425 SetPartitionedHlo(hlo, [&] {
426 auto* instr = b_.AddInstruction(
427 hlo->CloneWithNewOperands(hlo->shape(), new_operands));
428 if (hlo->shape().IsTuple()) {
429 std::vector<HloSharding> subshardings(
430 hlo->sharding().tuple_elements().size(),
431 HloSharding::AssignDevice(0));
432 instr->set_sharding(HloSharding::Tuple(hlo->shape(), subshardings));
433 } else {
434 instr->set_sharding(HloSharding::AssignDevice(0));
435 }
436 return instr;
437 });
438 return OkStatus();
439 }
440
441 return DefaultAction(hlo);
442 }
443
444 } // namespace spmd
445 } // namespace xla
446