1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/all_to_all_decomposer.h"
17
18 #include <optional>
19 #include <vector>
20
21 #include "absl/algorithm/container.h"
22 #include "absl/strings/str_join.h"
23 #include "tensorflow/compiler/xla/layout_util.h"
24 #include "tensorflow/compiler/xla/literal_util.h"
25 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
29 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/status_macros.h"
32 #include "tensorflow/compiler/xla/types.h"
33 #include "tensorflow/compiler/xla/util.h"
34 #include "tensorflow/core/platform/logging.h"
35
36 namespace xla {
InstructionMatchesPattern(HloInstruction * instruction)37 bool AllToAllDecomposer::InstructionMatchesPattern(
38 HloInstruction* instruction) {
39 auto* all_to_all = DynCast<HloAllToAllInstruction>(instruction);
40 if (all_to_all == nullptr) {
41 return false;
42 }
43 // Do not attempt to change layout constrained collectives.
44 if (all_to_all->constrain_layout()) {
45 return false;
46 }
47 if (all_to_all->shape().IsTuple()) {
48 return false;
49 }
50 if (decompose_to_tuple_) {
51 return true;
52 }
53 return all_to_all->shape().rank() < min_array_rank_;
54 }
ExpandInstruction(HloInstruction * instruction)55 StatusOr<HloInstruction*> AllToAllDecomposer::ExpandInstruction(
56 HloInstruction* instruction) {
57 auto* all_to_all = Cast<HloAllToAllInstruction>(instruction);
58 int64_t split_dim = *all_to_all->split_dimension();
59 int64_t all_to_all_group_size =
60 all_to_all->replica_groups().empty()
61 ? instruction->parent()->parent()->config().replica_count()
62 : all_to_all->replica_groups()[0].replica_ids_size();
63 int64_t split_size =
64 all_to_all->shape().dimensions(split_dim) / all_to_all_group_size;
65 if (!decompose_to_tuple_) {
66 Shape new_all_to_all_shape;
67 new_all_to_all_shape.set_element_type(
68 instruction->operand(0)->shape().element_type());
69 for (int64_t i = 0; i < instruction->shape().rank(); ++i) {
70 if (i != split_dim) {
71 new_all_to_all_shape.add_dimensions(all_to_all->shape().dimensions(i));
72 continue;
73 }
74 new_all_to_all_shape.add_dimensions(all_to_all_group_size);
75 new_all_to_all_shape.add_dimensions(split_size);
76 for (int64_t j = all_to_all->shape().rank() + 1; j < min_array_rank_;
77 ++j) {
78 new_all_to_all_shape.add_dimensions(1);
79 }
80 }
81 *(new_all_to_all_shape.mutable_layout()) =
82 LayoutUtil::GetDefaultLayoutForRank(min_array_rank_);
83 HloInstruction* operand_reshape =
84 instruction->parent()->AddInstruction(HloInstruction::CreateReshape(
85 new_all_to_all_shape, instruction->mutable_operand(0)));
86 instruction->SetupDerivedInstruction(operand_reshape);
87 HloInstruction* all_to_all =
88 instruction->parent()->AddInstruction(instruction->CloneWithNewOperands(
89 new_all_to_all_shape, {operand_reshape}));
90 HloInstruction* output_reshape = instruction->parent()->AddInstruction(
91 HloInstruction::CreateReshape(instruction->shape(), all_to_all));
92 instruction->SetupDerivedInstruction(output_reshape);
93 return output_reshape;
94 }
95 DimensionVector slice_starts(all_to_all->shape().rank(), 0);
96 DimensionVector slice_strides(all_to_all->shape().rank(), 1);
97 DimensionVector slice_limits(all_to_all->shape().dimensions().begin(),
98 all_to_all->shape().dimensions().end());
99 slice_limits[split_dim] = split_size;
100 Shape slice_shape = all_to_all->shape();
101 slice_shape.set_dimensions(split_dim, split_size);
102 std::vector<HloInstruction*> slices;
103 slices.reserve(all_to_all_group_size);
104 HloInstruction* operand = all_to_all->mutable_operand(0);
105 for (int64_t i = 0; i < all_to_all_group_size; ++i) {
106 slices.push_back(
107 all_to_all->parent()->AddInstruction(HloInstruction::CreateSlice(
108 slice_shape, operand, slice_starts, slice_limits, slice_strides)));
109 all_to_all->SetupDerivedInstruction(slices.back());
110 slice_starts[split_dim] = slice_limits[split_dim];
111 slice_limits[split_dim] += split_size;
112 }
113 Shape all_to_all_shape = ShapeUtil::MakeTupleShapeWithPtrs(
114 std::vector<const Shape*>(all_to_all_group_size, &slice_shape));
115 HloInstruction* new_all_to_all =
116 all_to_all->parent()->AddInstruction(HloInstruction::CreateAllToAll(
117 all_to_all_shape, slices, all_to_all->replica_groups(), false,
118 all_to_all->channel_id(), std::nullopt));
119 std::vector<HloInstruction*> gtes;
120 gtes.reserve(all_to_all_group_size);
121 for (int64_t i = 0; i < all_to_all_group_size; ++i) {
122 gtes.push_back(all_to_all->parent()->AddInstruction(
123 HloInstruction::CreateGetTupleElement(slice_shape, new_all_to_all, i)));
124 all_to_all->SetupDerivedInstruction(new_all_to_all);
125 }
126 HloInstruction* concat = all_to_all->parent()->AddInstruction(
127 HloInstruction::CreateConcatenate(all_to_all->shape(), gtes, split_dim));
128 all_to_all->SetupDerivedInstruction(concat);
129 return concat;
130 }
131
132 } // namespace xla
133