/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/all_to_all_decomposer.h"

#include <optional>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
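// Matches the all-to-alls this pass rewrites: array-shaped (non-tuple),
// non-layout-constrained instances, and, when not decomposing to tuple form,
// only those whose rank is below min_array_rank_.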
bool AllToAllDecomposer::InstructionMatchesPattern(
    HloInstruction* instruction) {
  auto* all_to_all = DynCast<HloAllToAllInstruction>(instruction);
  if (all_to_all == nullptr) {
    return false;
  }
  // Do not attempt to change layout constrained collectives.
  if (all_to_all->constrain_layout()) {
    return false;
  }
  if (all_to_all->shape().IsTuple()) {
    return false;
  }
  if (decompose_to_tuple_) {
    return true;
  }
  return all_to_all->shape().rank() < min_array_rank_;
}
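
// Rewrites a matched all-to-all. When decompose_to_tuple_ is false, the array
// all-to-all is padded up to min_array_rank_ by reshaping around its split
// dimension; otherwise it is decomposed into a tuple-shaped all-to-all over
// per-participant slices that are reassembled with a concatenate.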
StatusOr<HloInstruction*> AllToAllDecomposer::ExpandInstruction(
    HloInstruction* instruction) {
  auto* all_to_all = Cast<HloAllToAllInstruction>(instruction);
  int64_t split_dim = *all_to_all->split_dimension();
  int64_t all_to_all_group_size =
      all_to_all->replica_groups().empty()
          ? instruction->parent()->parent()->config().replica_count()
          : all_to_all->replica_groups()[0].replica_ids_size();
  int64_t split_size =
      all_to_all->shape().dimensions(split_dim) / all_to_all_group_size;
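  // Rank-padding path: build a shape of rank min_array_rank_ in which the
  // split dimension is factored into (group_size, split_size), followed by
  // size-1 dimensions until the total rank reaches min_array_rank_.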
  if (!decompose_to_tuple_) {
    Shape new_all_to_all_shape;
    new_all_to_all_shape.set_element_type(
        instruction->operand(0)->shape().element_type());
    for (int64_t i = 0; i < instruction->shape().rank(); ++i) {
      if (i != split_dim) {
        new_all_to_all_shape.add_dimensions(all_to_all->shape().dimensions(i));
        continue;
      }
      new_all_to_all_shape.add_dimensions(all_to_all_group_size);
      new_all_to_all_shape.add_dimensions(split_size);
      for (int64_t j = all_to_all->shape().rank() + 1; j < min_array_rank_;
           ++j) {
        new_all_to_all_shape.add_dimensions(1);
      }
    }
    *(new_all_to_all_shape.mutable_layout()) =
        LayoutUtil::GetDefaultLayoutForRank(min_array_rank_);
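    // Reshape the operand into the padded shape, run the all-to-all on it, and
    // reshape the result back to the original output shape.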
    HloInstruction* operand_reshape =
        instruction->parent()->AddInstruction(HloInstruction::CreateReshape(
            new_all_to_all_shape, instruction->mutable_operand(0)));
    instruction->SetupDerivedInstruction(operand_reshape);
    HloInstruction* all_to_all =
        instruction->parent()->AddInstruction(instruction->CloneWithNewOperands(
            new_all_to_all_shape, {operand_reshape}));
    HloInstruction* output_reshape = instruction->parent()->AddInstruction(
        HloInstruction::CreateReshape(instruction->shape(), all_to_all));
    instruction->SetupDerivedInstruction(output_reshape);
    return output_reshape;
  }
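  // Tuple path: slice the operand along the split dimension into one piece per
  // participant, exchange the pieces with a tuple-shaped all-to-all, and
  // concatenate the received pieces back along the split dimension.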
  DimensionVector slice_starts(all_to_all->shape().rank(), 0);
  DimensionVector slice_strides(all_to_all->shape().rank(), 1);
  DimensionVector slice_limits(all_to_all->shape().dimensions().begin(),
                               all_to_all->shape().dimensions().end());
  slice_limits[split_dim] = split_size;
  Shape slice_shape = all_to_all->shape();
  slice_shape.set_dimensions(split_dim, split_size);
  std::vector<HloInstruction*> slices;
  slices.reserve(all_to_all_group_size);
  HloInstruction* operand = all_to_all->mutable_operand(0);
  for (int64_t i = 0; i < all_to_all_group_size; ++i) {
    slices.push_back(
        all_to_all->parent()->AddInstruction(HloInstruction::CreateSlice(
            slice_shape, operand, slice_starts, slice_limits, slice_strides)));
    all_to_all->SetupDerivedInstruction(slices.back());
    slice_starts[split_dim] = slice_limits[split_dim];
    slice_limits[split_dim] += split_size;
  }
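  // Build the tuple-shaped all-to-all over the slices; it takes no split
  // dimension since the split is now explicit in the operands.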
  Shape all_to_all_shape = ShapeUtil::MakeTupleShapeWithPtrs(
      std::vector<const Shape*>(all_to_all_group_size, &slice_shape));
  HloInstruction* new_all_to_all =
      all_to_all->parent()->AddInstruction(HloInstruction::CreateAllToAll(
          all_to_all_shape, slices, all_to_all->replica_groups(), false,
          all_to_all->channel_id(), std::nullopt));
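  // Extract each tuple element and concatenate them along the original split
  // dimension to reconstruct the expected output shape.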
  std::vector<HloInstruction*> gtes;
  gtes.reserve(all_to_all_group_size);
  for (int64_t i = 0; i < all_to_all_group_size; ++i) {
    gtes.push_back(all_to_all->parent()->AddInstruction(
        HloInstruction::CreateGetTupleElement(slice_shape, new_all_to_all, i)));
    all_to_all->SetupDerivedInstruction(new_all_to_all);
  }
  HloInstruction* concat = all_to_all->parent()->AddInstruction(
      HloInstruction::CreateConcatenate(all_to_all->shape(), gtes, split_dim));
  all_to_all->SetupDerivedInstruction(concat);
  return concat;
}

}  // namespace xla