xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_reachability.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_reachability.h"
17 
18 #include <queue>
19 
20 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
21 
22 namespace xla {
23 
HloReachabilityMap(absl::Span<const HloInstruction * const> instructions)24 HloReachabilityMap::HloReachabilityMap(
25     absl::Span<const HloInstruction* const> instructions)
26     : size_(instructions.size()) {
27   bit_vectors_.reserve(size_);
28   for (const HloInstruction* hlo : instructions) {
29     indices_[GetKey(hlo)] = bit_vectors_.size();
30     bit_vectors_.emplace_back(size_);
31   }
32   CHECK_EQ(size_, indices_.size());  // instructions should be unique
33 }
34 
SetReachabilityToUnion(absl::Span<const HloInstruction * const> inputs,const HloInstruction * instruction)35 bool HloReachabilityMap::SetReachabilityToUnion(
36     absl::Span<const HloInstruction* const> inputs,
37     const HloInstruction* instruction) {
38   Index index = GetIndex(instruction);
39   BitVector& bit_vector = GetBitVector(index);
40   tmp_bit_vector_ = bit_vector;
41   SetReachabilityToUnionHelper(inputs, index);
42   return bit_vector != tmp_bit_vector_;
43 }
44 
FastSetReachabilityToUnion(absl::Span<const HloInstruction * const> inputs,const HloInstruction * instruction)45 void HloReachabilityMap::FastSetReachabilityToUnion(
46     absl::Span<const HloInstruction* const> inputs,
47     const HloInstruction* instruction) {
48   Index index = GetIndex(instruction);
49   SetReachabilityToUnionHelper(inputs, index);
50 }
51 
FastSetReachabilityToUnion(absl::Span<const Index> input_indices,Index index)52 void HloReachabilityMap::FastSetReachabilityToUnion(
53     absl::Span<const Index> input_indices, Index index) {
54   SetReachabilityToUnionHelper(input_indices, index);
55 }
56 
SetReachabilityToUnionHelper(absl::Span<const HloInstruction * const> inputs,Index index)57 void HloReachabilityMap::SetReachabilityToUnionHelper(
58     absl::Span<const HloInstruction* const> inputs, Index index) {
59   absl::InlinedVector<Index, 16> input_indices;
60   input_indices.reserve(inputs.size());
61   for (const HloInstruction* input : inputs) {
62     input_indices.push_back(GetIndex(input));
63   }
64   SetReachabilityToUnionHelper(input_indices, index);
65 }
66 
SetReachabilityToUnionHelper(absl::Span<const Index> input_indices,Index index)67 void HloReachabilityMap::SetReachabilityToUnionHelper(
68     absl::Span<const Index> input_indices, Index index) {
69   BitVector& bit_vector = GetBitVector(index);
70   // If instruction is part of inputs, don't reset the bit_vector.
71   if (!absl::c_linear_search(input_indices, index)) {
72     bit_vector.SetToZero();
73   }
74   bit_vector.Set(index.v);
75   for (Index input_index : input_indices) {
76     if (input_index != index) {
77       bit_vector.OrWith(GetBitVector(input_index));
78     }
79   }
80 }
81 
Replace(const HloInstruction * original,const HloInstruction * replacement)82 void HloReachabilityMap::Replace(const HloInstruction* original,
83                                  const HloInstruction* replacement) {
84   if (GetKey(original) == GetKey(replacement)) {
85     return;
86   }
87   indices_[GetKey(replacement)] = GetIndex(original).v;
88   indices_.erase(GetKey(original));
89 }
90 
SetReachable(Index a,Index b)91 void HloReachabilityMap::SetReachable(Index a, Index b) {
92   GetBitVector(b).Set(a.v);
93 }
94 
BuildWithRestrictions(const HloComputation * computation,absl::FunctionRef<void (const HloInstruction *,std::vector<HloInstruction * > *)> add_dependencies)95 std::unique_ptr<HloReachabilityMap> HloReachabilityMap::BuildWithRestrictions(
96     const HloComputation* computation,
97     absl::FunctionRef<void(const HloInstruction*,
98                            std::vector<HloInstruction*>*)>
99         add_dependencies) {
100   const auto& all = computation->MakeInstructionPostOrder();
101   auto result = std::make_unique<HloReachabilityMap>(all);
102 
103   std::vector<HloInstruction*> inputs;
104   for (const HloInstruction* hlo : all) {
105     inputs.clear();
106     add_dependencies(hlo, &inputs);
107     result->FastSetReachabilityToUnion(inputs, hlo);
108   }
109   return result;
110 }
111 
Build(const HloComputation * computation)112 std::unique_ptr<HloReachabilityMap> HloReachabilityMap::Build(
113     const HloComputation* computation) {
114   const auto& all = computation->MakeInstructionPostOrder();
115   auto result = std::make_unique<HloReachabilityMap>(all);
116   auto channel_group = computation->ComputeChannelDependencies();
117 
118   std::vector<HloInstruction*> inputs;
119 
120   const auto add_input = [&channel_group, &inputs](HloInstruction* input) {
121     inputs.push_back(input);
122     if ((input->opcode() == HloOpcode::kAllReduce ||
123          input->opcode() == HloOpcode::kReduceScatter) &&
124         input->channel_id()) {
125       auto it = channel_group.find(*input->channel_id());
126       if (it != channel_group.end()) {
127         inputs.insert(inputs.end(), it->second.begin(), it->second.end());
128       }
129     }
130   };
131 
132   const auto add_dependencies = [&add_input](const HloInstruction* hlo) {
133     for (HloInstruction* operand : hlo->operands()) {
134       add_input(operand);
135     }
136     for (HloInstruction* predecessor : hlo->control_predecessors()) {
137       add_input(predecessor);
138     }
139   };
140 
141   for (const HloInstruction* hlo : all) {
142     inputs.clear();
143     add_dependencies(hlo);
144 
145     switch (hlo->opcode()) {
146       case HloOpcode::kRecvDone: {
147         auto it = channel_group.find(*hlo->channel_id());
148         if (it != channel_group.end()) {
149           for (HloInstruction* channel : it->second) {
150             if (channel->opcode() == HloOpcode::kSend) {
151               add_input(channel);
152             }
153           }
154         }
155         break;
156       }
157       case HloOpcode::kAllReduce:
158       case HloOpcode::kReduceScatter: {
159         auto channel_id = hlo->channel_id();
160         if (channel_id) {
161           auto it = channel_group.find(channel_id.value());
162           if (it != channel_group.end()) {
163             for (HloInstruction* all_reduce : it->second) {
164               add_dependencies(all_reduce);
165             }
166           }
167         }
168         break;
169       }
170       default:
171         break;
172     }
173 
174     result->FastSetReachabilityToUnion(inputs, hlo);
175   }
176   return result;
177 }
178 
UpdateReachabilityThroughInstruction(const HloInstruction * instruction)179 void HloReachabilityMap::UpdateReachabilityThroughInstruction(
180     const HloInstruction* instruction) {
181   std::queue<const HloInstruction*> worklist;
182   worklist.push(instruction);
183 
184   std::vector<HloInstruction*> inputs;
185 
186   while (!worklist.empty()) {
187     const HloInstruction* item = worklist.front();
188     worklist.pop();
189 
190     inputs.assign(item->operands().begin(), item->operands().end());
191     inputs.insert(inputs.end(), item->control_predecessors().begin(),
192                   item->control_predecessors().end());
193 
194     if (SetReachabilityToUnion(inputs, item)) {
195       // Add immediate successors to worklist.
196       for (const HloInstruction* user : item->users()) {
197         worklist.push(user);
198       }
199       for (const HloInstruction* succ : item->control_successors()) {
200         worklist.push(succ);
201       }
202     }
203   }
204 }
205 
206 }  // namespace xla
207