1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_reachability.h"
17
18 #include <queue>
19
20 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
21
22 namespace xla {
23
HloReachabilityMap(absl::Span<const HloInstruction * const> instructions)24 HloReachabilityMap::HloReachabilityMap(
25 absl::Span<const HloInstruction* const> instructions)
26 : size_(instructions.size()) {
27 bit_vectors_.reserve(size_);
28 for (const HloInstruction* hlo : instructions) {
29 indices_[GetKey(hlo)] = bit_vectors_.size();
30 bit_vectors_.emplace_back(size_);
31 }
32 CHECK_EQ(size_, indices_.size()); // instructions should be unique
33 }
34
SetReachabilityToUnion(absl::Span<const HloInstruction * const> inputs,const HloInstruction * instruction)35 bool HloReachabilityMap::SetReachabilityToUnion(
36 absl::Span<const HloInstruction* const> inputs,
37 const HloInstruction* instruction) {
38 Index index = GetIndex(instruction);
39 BitVector& bit_vector = GetBitVector(index);
40 tmp_bit_vector_ = bit_vector;
41 SetReachabilityToUnionHelper(inputs, index);
42 return bit_vector != tmp_bit_vector_;
43 }
44
FastSetReachabilityToUnion(absl::Span<const HloInstruction * const> inputs,const HloInstruction * instruction)45 void HloReachabilityMap::FastSetReachabilityToUnion(
46 absl::Span<const HloInstruction* const> inputs,
47 const HloInstruction* instruction) {
48 Index index = GetIndex(instruction);
49 SetReachabilityToUnionHelper(inputs, index);
50 }
51
FastSetReachabilityToUnion(absl::Span<const Index> input_indices,Index index)52 void HloReachabilityMap::FastSetReachabilityToUnion(
53 absl::Span<const Index> input_indices, Index index) {
54 SetReachabilityToUnionHelper(input_indices, index);
55 }
56
SetReachabilityToUnionHelper(absl::Span<const HloInstruction * const> inputs,Index index)57 void HloReachabilityMap::SetReachabilityToUnionHelper(
58 absl::Span<const HloInstruction* const> inputs, Index index) {
59 absl::InlinedVector<Index, 16> input_indices;
60 input_indices.reserve(inputs.size());
61 for (const HloInstruction* input : inputs) {
62 input_indices.push_back(GetIndex(input));
63 }
64 SetReachabilityToUnionHelper(input_indices, index);
65 }
66
SetReachabilityToUnionHelper(absl::Span<const Index> input_indices,Index index)67 void HloReachabilityMap::SetReachabilityToUnionHelper(
68 absl::Span<const Index> input_indices, Index index) {
69 BitVector& bit_vector = GetBitVector(index);
70 // If instruction is part of inputs, don't reset the bit_vector.
71 if (!absl::c_linear_search(input_indices, index)) {
72 bit_vector.SetToZero();
73 }
74 bit_vector.Set(index.v);
75 for (Index input_index : input_indices) {
76 if (input_index != index) {
77 bit_vector.OrWith(GetBitVector(input_index));
78 }
79 }
80 }
81
Replace(const HloInstruction * original,const HloInstruction * replacement)82 void HloReachabilityMap::Replace(const HloInstruction* original,
83 const HloInstruction* replacement) {
84 if (GetKey(original) == GetKey(replacement)) {
85 return;
86 }
87 indices_[GetKey(replacement)] = GetIndex(original).v;
88 indices_.erase(GetKey(original));
89 }
90
SetReachable(Index a,Index b)91 void HloReachabilityMap::SetReachable(Index a, Index b) {
92 GetBitVector(b).Set(a.v);
93 }
94
BuildWithRestrictions(const HloComputation * computation,absl::FunctionRef<void (const HloInstruction *,std::vector<HloInstruction * > *)> add_dependencies)95 std::unique_ptr<HloReachabilityMap> HloReachabilityMap::BuildWithRestrictions(
96 const HloComputation* computation,
97 absl::FunctionRef<void(const HloInstruction*,
98 std::vector<HloInstruction*>*)>
99 add_dependencies) {
100 const auto& all = computation->MakeInstructionPostOrder();
101 auto result = std::make_unique<HloReachabilityMap>(all);
102
103 std::vector<HloInstruction*> inputs;
104 for (const HloInstruction* hlo : all) {
105 inputs.clear();
106 add_dependencies(hlo, &inputs);
107 result->FastSetReachabilityToUnion(inputs, hlo);
108 }
109 return result;
110 }
111
Build(const HloComputation * computation)112 std::unique_ptr<HloReachabilityMap> HloReachabilityMap::Build(
113 const HloComputation* computation) {
114 const auto& all = computation->MakeInstructionPostOrder();
115 auto result = std::make_unique<HloReachabilityMap>(all);
116 auto channel_group = computation->ComputeChannelDependencies();
117
118 std::vector<HloInstruction*> inputs;
119
120 const auto add_input = [&channel_group, &inputs](HloInstruction* input) {
121 inputs.push_back(input);
122 if ((input->opcode() == HloOpcode::kAllReduce ||
123 input->opcode() == HloOpcode::kReduceScatter) &&
124 input->channel_id()) {
125 auto it = channel_group.find(*input->channel_id());
126 if (it != channel_group.end()) {
127 inputs.insert(inputs.end(), it->second.begin(), it->second.end());
128 }
129 }
130 };
131
132 const auto add_dependencies = [&add_input](const HloInstruction* hlo) {
133 for (HloInstruction* operand : hlo->operands()) {
134 add_input(operand);
135 }
136 for (HloInstruction* predecessor : hlo->control_predecessors()) {
137 add_input(predecessor);
138 }
139 };
140
141 for (const HloInstruction* hlo : all) {
142 inputs.clear();
143 add_dependencies(hlo);
144
145 switch (hlo->opcode()) {
146 case HloOpcode::kRecvDone: {
147 auto it = channel_group.find(*hlo->channel_id());
148 if (it != channel_group.end()) {
149 for (HloInstruction* channel : it->second) {
150 if (channel->opcode() == HloOpcode::kSend) {
151 add_input(channel);
152 }
153 }
154 }
155 break;
156 }
157 case HloOpcode::kAllReduce:
158 case HloOpcode::kReduceScatter: {
159 auto channel_id = hlo->channel_id();
160 if (channel_id) {
161 auto it = channel_group.find(channel_id.value());
162 if (it != channel_group.end()) {
163 for (HloInstruction* all_reduce : it->second) {
164 add_dependencies(all_reduce);
165 }
166 }
167 }
168 break;
169 }
170 default:
171 break;
172 }
173
174 result->FastSetReachabilityToUnion(inputs, hlo);
175 }
176 return result;
177 }
178
UpdateReachabilityThroughInstruction(const HloInstruction * instruction)179 void HloReachabilityMap::UpdateReachabilityThroughInstruction(
180 const HloInstruction* instruction) {
181 std::queue<const HloInstruction*> worklist;
182 worklist.push(instruction);
183
184 std::vector<HloInstruction*> inputs;
185
186 while (!worklist.empty()) {
187 const HloInstruction* item = worklist.front();
188 worklist.pop();
189
190 inputs.assign(item->operands().begin(), item->operands().end());
191 inputs.insert(inputs.end(), item->control_predecessors().begin(),
192 item->control_predecessors().end());
193
194 if (SetReachabilityToUnion(inputs, item)) {
195 // Add immediate successors to worklist.
196 for (const HloInstruction* user : item->users()) {
197 worklist.push(user);
198 }
199 for (const HloInstruction* succ : item->control_successors()) {
200 worklist.push(succ);
201 }
202 }
203 }
204 }
205
206 } // namespace xla
207