1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_domain_map.h"
17
18 #include <algorithm>
19 #include <memory>
20 #include <string>
21 #include <utility>
22
23 #include "absl/container/flat_hash_map.h"
24 #include "absl/container/flat_hash_set.h"
25 #include "tensorflow/compiler/xla/map_util.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/types.h"
28
29 namespace xla {
30
Create(HloComputation * computation,std::string domain_kind)31 /* static */ StatusOr<std::unique_ptr<HloDomainMap>> HloDomainMap::Create(
32 HloComputation* computation, std::string domain_kind) {
33 auto domain_map = absl::WrapUnique(new HloDomainMap(std::move(domain_kind)));
34 TF_RETURN_IF_ERROR(domain_map->Populate(computation));
35 return std::move(domain_map);
36 }
37
Create(HloModule * module,std::string domain_kind)38 /* static */ StatusOr<std::unique_ptr<HloDomainMap>> HloDomainMap::Create(
39 HloModule* module, std::string domain_kind) {
40 auto domain_map = absl::WrapUnique(new HloDomainMap(std::move(domain_kind)));
41 for (HloComputation* computation : module->computations()) {
42 TF_RETURN_IF_ERROR(domain_map->Populate(computation));
43 }
44 return std::move(domain_map);
45 }
46
InSameDomain(const HloInstruction * instruction1,const HloInstruction * instruction2) const47 bool HloDomainMap::InSameDomain(const HloInstruction* instruction1,
48 const HloInstruction* instruction2) const {
49 int64_t domain_id1 = GetDomainId(instruction1);
50 int64_t domain_id2 = GetDomainId(instruction2);
51 return domain_id1 >= 0 && domain_id1 == domain_id2;
52 }
53
GetDomainId(const HloInstruction * instruction) const54 int64_t HloDomainMap::GetDomainId(const HloInstruction* instruction) const {
55 return FindOrDefault(instruction_to_domain_, instruction, -1);
56 }
57
GetDomainMetadataId(const HloInstruction * instruction) const58 int64_t HloDomainMap::GetDomainMetadataId(
59 const HloInstruction* instruction) const {
60 return FindOrDie(domain_metadata_id_, instruction);
61 }
62
TryProcessEmptyDomain(HloInstruction * instruction)63 Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) {
64 TF_RET_CHECK(instruction->opcode() == HloOpcode::kDomain);
65 // We only check operands, so we are sure to not process the empty domain from
66 // both sides.
67 for (HloInstruction* operand : instruction->unique_operands()) {
68 if (IsDomainInstruction(operand)) {
69 auto domain = std::make_unique<DomainMetadata::Domain>();
70 domain->enter_domains.insert(operand);
71 domain->exit_domains.insert(instruction);
72 TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
73 }
74 }
75 if (instruction == instruction->parent()->root_instruction()) {
76 auto domain = std::make_unique<DomainMetadata::Domain>();
77 domain->enter_domains.insert(instruction);
78 TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
79 }
80 return OkStatus();
81 }
82
Populate(HloComputation * computation)83 Status HloDomainMap::Populate(HloComputation* computation) {
84 InstructionOrderMap instructions_post_order;
85 int64_t count = 0;
86 for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
87 instructions_post_order.insert(std::make_pair(instruction, count++));
88 }
89 for (HloInstruction* instruction : computation->instructions()) {
90 if (IsDomainInstruction(instruction)) {
91 // If this is a kDomain of the kind we are currently processing, check
92 // whether this is an "empty domain".
93 TF_RETURN_IF_ERROR(TryProcessEmptyDomain(instruction));
94 continue;
95 }
96 int64_t domain_id = FindOrDefault(instruction_to_domain_, instruction, -1);
97 if (domain_id >= 0) {
98 // We have already processed this instruction.
99 continue;
100 }
101 TF_ASSIGN_OR_RETURN(std::unique_ptr<DomainMetadata::Domain> domain,
102 CreateDomain(instruction, instructions_post_order));
103 TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
104 }
105 TF_RETURN_IF_ERROR(PopulateDomainMetadataMap());
106 return OkStatus();
107 }
108
PopulateDomainMetadataMap()109 Status HloDomainMap::PopulateDomainMetadataMap() {
110 auto hash = [](const DomainMetadata* m) { return m->Hash(); };
111 auto equal = [](const DomainMetadata* a, const DomainMetadata* b) {
112 return a->Matches(*b);
113 };
114 absl::flat_hash_map<const DomainMetadata*, int64_t, decltype(hash),
115 decltype(equal)>
116 domain_metadata(1024, hash, equal);
117
118 for (auto& domain : instruction_domains_) {
119 int64_t domain_metadata_id = -1;
120 if (!domain->enter_domains.empty()) {
121 const HloInstruction* domain_instruction = *domain->enter_domains.begin();
122 domain_metadata_id =
123 domain_metadata
124 .insert({&domain_instruction->user_side_metadata(),
125 domain_metadata.size() + 1})
126 .first->second;
127 } else if (!domain->exit_domains.empty()) {
128 const HloInstruction* domain_instruction = *domain->exit_domains.begin();
129 domain_metadata_id =
130 domain_metadata
131 .insert({&domain_instruction->operand_side_metadata(),
132 domain_metadata.size() + 1})
133 .first->second;
134 } else {
135 domain_metadata_id = 0;
136 }
137 TF_RET_CHECK(domain_metadata_id >= 0);
138 for (HloInstruction* instruction : domain->instructions) {
139 domain_metadata_id_[instruction] = domain_metadata_id;
140 }
141 }
142 return OkStatus();
143 }
144
InsertDomain(std::unique_ptr<DomainMetadata::Domain> domain)145 Status HloDomainMap::InsertDomain(
146 std::unique_ptr<DomainMetadata::Domain> domain) {
147 int64_t domain_id = instruction_domains_.size();
148 instruction_domains_.push_back(std::move(domain));
149 for (HloInstruction* instruction : instruction_domains_.back()->reach_set) {
150 instruction_to_domain_[instruction] = domain_id;
151 }
152 return OkStatus();
153 }
154
ExpandDomain(HloInstruction * instruction,DomainMetadata::Domain * domain) const155 Status HloDomainMap::ExpandDomain(HloInstruction* instruction,
156 DomainMetadata::Domain* domain) const {
157 std::vector<HloInstruction*> in_queue;
158 in_queue.push_back(instruction);
159 while (!in_queue.empty()) {
160 HloInstruction* current_instruction = in_queue.back();
161 in_queue.pop_back();
162 if (domain->reach_set.insert(current_instruction).second) {
163 // We should not be finding instructions with assigned domain here.
164 // If we assigned a domain to the instruction, it means that all the
165 // instructions reached by it, should have a domain as well.
166 int64_t domain_id =
167 FindOrDefault(instruction_to_domain_, current_instruction, -1);
168 TF_RET_CHECK(domain_id < 0)
169 << "Instruction " << current_instruction->ToString()
170 << " already has domain " << domain_id;
171 for (HloInstruction* operand : current_instruction->operands()) {
172 if (IsDomainInstruction(operand)) {
173 // The reach set instruction is a user of the domain instruction
174 // (the instruction sees the kDomain as operand).
175 // IOW the dataflow enters the domain through the kDomain instruction.
176 domain->enter_domains.insert(operand);
177 } else {
178 in_queue.push_back(operand);
179 }
180 }
181 for (HloInstruction* user : current_instruction->users()) {
182 if (IsDomainInstruction(user)) {
183 // The reach set instruction is an operand of the domain instruction
184 // (the instruction sees the kDomain as user).
185 // IOW the dataflow exits the domain through the kDomain instruction.
186 domain->exit_domains.insert(user);
187 } else {
188 in_queue.push_back(user);
189 }
190 }
191 }
192 }
193 return OkStatus();
194 }
195
CreateDomain(HloInstruction * instruction,const InstructionOrderMap & instructions_order) const196 StatusOr<std::unique_ptr<DomainMetadata::Domain>> HloDomainMap::CreateDomain(
197 HloInstruction* instruction,
198 const InstructionOrderMap& instructions_order) const {
199 auto domain = std::make_unique<DomainMetadata::Domain>();
200 TF_RETURN_IF_ERROR(ExpandDomain(instruction, domain.get()));
201 domain->instructions =
202 MakeNonDomainInstructions(domain->reach_set, instructions_order);
203 return std::move(domain);
204 }
205
IsDomainInstruction(const HloInstruction * instruction) const206 bool HloDomainMap::IsDomainInstruction(
207 const HloInstruction* instruction) const {
208 if (instruction->opcode() != HloOpcode::kDomain) {
209 return false;
210 }
211 if (!domain_kind_.empty()) {
212 if (instruction->user_side_metadata().Kind() != domain_kind_) {
213 return false;
214 }
215 // Both user and operand side of the metadata must be of the same kind.
216 CHECK(instruction->operand_side_metadata().Kind() == domain_kind_)
217 << "Instruction " << instruction->ToString()
218 << " has mismatching metadata kinds";
219 }
220 return true;
221 }
222
223 /* static */ std::vector<HloInstruction*>
MakeNonDomainInstructions(const absl::flat_hash_set<HloInstruction * > & instruction_set,const InstructionOrderMap & instructions_order)224 HloDomainMap::MakeNonDomainInstructions(
225 const absl::flat_hash_set<HloInstruction*>& instruction_set,
226 const InstructionOrderMap& instructions_order) {
227 std::vector<HloInstruction*> instructions;
228 instructions.reserve(instruction_set.size());
229 for (HloInstruction* instruction : instruction_set) {
230 if (instruction->opcode() != HloOpcode::kDomain) {
231 instructions.push_back(instruction);
232 }
233 }
234 // sort instructions according to instructions_order
235 absl::c_sort(instructions,
236 [&instructions_order](HloInstruction* a, HloInstruction* b) {
237 return instructions_order.at(a) < instructions_order.at(b);
238 });
239 return instructions;
240 }
241
242 } // namespace xla
243