xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_domain_map.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_domain_map.h"
17 
18 #include <algorithm>
19 #include <memory>
20 #include <string>
21 #include <utility>
22 
23 #include "absl/container/flat_hash_map.h"
24 #include "absl/container/flat_hash_set.h"
25 #include "tensorflow/compiler/xla/map_util.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/types.h"
28 
29 namespace xla {
30 
Create(HloComputation * computation,std::string domain_kind)31 /* static */ StatusOr<std::unique_ptr<HloDomainMap>> HloDomainMap::Create(
32     HloComputation* computation, std::string domain_kind) {
33   auto domain_map = absl::WrapUnique(new HloDomainMap(std::move(domain_kind)));
34   TF_RETURN_IF_ERROR(domain_map->Populate(computation));
35   return std::move(domain_map);
36 }
37 
Create(HloModule * module,std::string domain_kind)38 /* static */ StatusOr<std::unique_ptr<HloDomainMap>> HloDomainMap::Create(
39     HloModule* module, std::string domain_kind) {
40   auto domain_map = absl::WrapUnique(new HloDomainMap(std::move(domain_kind)));
41   for (HloComputation* computation : module->computations()) {
42     TF_RETURN_IF_ERROR(domain_map->Populate(computation));
43   }
44   return std::move(domain_map);
45 }
46 
InSameDomain(const HloInstruction * instruction1,const HloInstruction * instruction2) const47 bool HloDomainMap::InSameDomain(const HloInstruction* instruction1,
48                                 const HloInstruction* instruction2) const {
49   int64_t domain_id1 = GetDomainId(instruction1);
50   int64_t domain_id2 = GetDomainId(instruction2);
51   return domain_id1 >= 0 && domain_id1 == domain_id2;
52 }
53 
GetDomainId(const HloInstruction * instruction) const54 int64_t HloDomainMap::GetDomainId(const HloInstruction* instruction) const {
55   return FindOrDefault(instruction_to_domain_, instruction, -1);
56 }
57 
GetDomainMetadataId(const HloInstruction * instruction) const58 int64_t HloDomainMap::GetDomainMetadataId(
59     const HloInstruction* instruction) const {
60   return FindOrDie(domain_metadata_id_, instruction);
61 }
62 
TryProcessEmptyDomain(HloInstruction * instruction)63 Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) {
64   TF_RET_CHECK(instruction->opcode() == HloOpcode::kDomain);
65   // We only check operands, so we are sure to not process the empty domain from
66   // both sides.
67   for (HloInstruction* operand : instruction->unique_operands()) {
68     if (IsDomainInstruction(operand)) {
69       auto domain = std::make_unique<DomainMetadata::Domain>();
70       domain->enter_domains.insert(operand);
71       domain->exit_domains.insert(instruction);
72       TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
73     }
74   }
75   if (instruction == instruction->parent()->root_instruction()) {
76     auto domain = std::make_unique<DomainMetadata::Domain>();
77     domain->enter_domains.insert(instruction);
78     TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
79   }
80   return OkStatus();
81 }
82 
Populate(HloComputation * computation)83 Status HloDomainMap::Populate(HloComputation* computation) {
84   InstructionOrderMap instructions_post_order;
85   int64_t count = 0;
86   for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
87     instructions_post_order.insert(std::make_pair(instruction, count++));
88   }
89   for (HloInstruction* instruction : computation->instructions()) {
90     if (IsDomainInstruction(instruction)) {
91       // If this is a kDomain of the kind we are currently processing, check
92       // whether this is an "empty domain".
93       TF_RETURN_IF_ERROR(TryProcessEmptyDomain(instruction));
94       continue;
95     }
96     int64_t domain_id = FindOrDefault(instruction_to_domain_, instruction, -1);
97     if (domain_id >= 0) {
98       // We have already processed this instruction.
99       continue;
100     }
101     TF_ASSIGN_OR_RETURN(std::unique_ptr<DomainMetadata::Domain> domain,
102                         CreateDomain(instruction, instructions_post_order));
103     TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
104   }
105   TF_RETURN_IF_ERROR(PopulateDomainMetadataMap());
106   return OkStatus();
107 }
108 
PopulateDomainMetadataMap()109 Status HloDomainMap::PopulateDomainMetadataMap() {
110   auto hash = [](const DomainMetadata* m) { return m->Hash(); };
111   auto equal = [](const DomainMetadata* a, const DomainMetadata* b) {
112     return a->Matches(*b);
113   };
114   absl::flat_hash_map<const DomainMetadata*, int64_t, decltype(hash),
115                       decltype(equal)>
116       domain_metadata(1024, hash, equal);
117 
118   for (auto& domain : instruction_domains_) {
119     int64_t domain_metadata_id = -1;
120     if (!domain->enter_domains.empty()) {
121       const HloInstruction* domain_instruction = *domain->enter_domains.begin();
122       domain_metadata_id =
123           domain_metadata
124               .insert({&domain_instruction->user_side_metadata(),
125                        domain_metadata.size() + 1})
126               .first->second;
127     } else if (!domain->exit_domains.empty()) {
128       const HloInstruction* domain_instruction = *domain->exit_domains.begin();
129       domain_metadata_id =
130           domain_metadata
131               .insert({&domain_instruction->operand_side_metadata(),
132                        domain_metadata.size() + 1})
133               .first->second;
134     } else {
135       domain_metadata_id = 0;
136     }
137     TF_RET_CHECK(domain_metadata_id >= 0);
138     for (HloInstruction* instruction : domain->instructions) {
139       domain_metadata_id_[instruction] = domain_metadata_id;
140     }
141   }
142   return OkStatus();
143 }
144 
InsertDomain(std::unique_ptr<DomainMetadata::Domain> domain)145 Status HloDomainMap::InsertDomain(
146     std::unique_ptr<DomainMetadata::Domain> domain) {
147   int64_t domain_id = instruction_domains_.size();
148   instruction_domains_.push_back(std::move(domain));
149   for (HloInstruction* instruction : instruction_domains_.back()->reach_set) {
150     instruction_to_domain_[instruction] = domain_id;
151   }
152   return OkStatus();
153 }
154 
ExpandDomain(HloInstruction * instruction,DomainMetadata::Domain * domain) const155 Status HloDomainMap::ExpandDomain(HloInstruction* instruction,
156                                   DomainMetadata::Domain* domain) const {
157   std::vector<HloInstruction*> in_queue;
158   in_queue.push_back(instruction);
159   while (!in_queue.empty()) {
160     HloInstruction* current_instruction = in_queue.back();
161     in_queue.pop_back();
162     if (domain->reach_set.insert(current_instruction).second) {
163       // We should not be finding instructions with assigned domain here.
164       // If we assigned a domain to the instruction, it means that all the
165       // instructions reached by it, should have a domain as well.
166       int64_t domain_id =
167           FindOrDefault(instruction_to_domain_, current_instruction, -1);
168       TF_RET_CHECK(domain_id < 0)
169           << "Instruction " << current_instruction->ToString()
170           << " already has domain " << domain_id;
171       for (HloInstruction* operand : current_instruction->operands()) {
172         if (IsDomainInstruction(operand)) {
173           // The reach set instruction is a user of the domain instruction
174           // (the instruction sees the kDomain as operand).
175           // IOW the dataflow enters the domain through the kDomain instruction.
176           domain->enter_domains.insert(operand);
177         } else {
178           in_queue.push_back(operand);
179         }
180       }
181       for (HloInstruction* user : current_instruction->users()) {
182         if (IsDomainInstruction(user)) {
183           // The reach set instruction is an operand of the domain instruction
184           // (the instruction sees the kDomain as user).
185           // IOW the dataflow exits the domain through the kDomain instruction.
186           domain->exit_domains.insert(user);
187         } else {
188           in_queue.push_back(user);
189         }
190       }
191     }
192   }
193   return OkStatus();
194 }
195 
CreateDomain(HloInstruction * instruction,const InstructionOrderMap & instructions_order) const196 StatusOr<std::unique_ptr<DomainMetadata::Domain>> HloDomainMap::CreateDomain(
197     HloInstruction* instruction,
198     const InstructionOrderMap& instructions_order) const {
199   auto domain = std::make_unique<DomainMetadata::Domain>();
200   TF_RETURN_IF_ERROR(ExpandDomain(instruction, domain.get()));
201   domain->instructions =
202       MakeNonDomainInstructions(domain->reach_set, instructions_order);
203   return std::move(domain);
204 }
205 
IsDomainInstruction(const HloInstruction * instruction) const206 bool HloDomainMap::IsDomainInstruction(
207     const HloInstruction* instruction) const {
208   if (instruction->opcode() != HloOpcode::kDomain) {
209     return false;
210   }
211   if (!domain_kind_.empty()) {
212     if (instruction->user_side_metadata().Kind() != domain_kind_) {
213       return false;
214     }
215     // Both user and operand side of the metadata must be of the same kind.
216     CHECK(instruction->operand_side_metadata().Kind() == domain_kind_)
217         << "Instruction " << instruction->ToString()
218         << " has mismatching metadata kinds";
219   }
220   return true;
221 }
222 
223 /* static */ std::vector<HloInstruction*>
MakeNonDomainInstructions(const absl::flat_hash_set<HloInstruction * > & instruction_set,const InstructionOrderMap & instructions_order)224 HloDomainMap::MakeNonDomainInstructions(
225     const absl::flat_hash_set<HloInstruction*>& instruction_set,
226     const InstructionOrderMap& instructions_order) {
227   std::vector<HloInstruction*> instructions;
228   instructions.reserve(instruction_set.size());
229   for (HloInstruction* instruction : instruction_set) {
230     if (instruction->opcode() != HloOpcode::kDomain) {
231       instructions.push_back(instruction);
232     }
233   }
234   // sort instructions according to instructions_order
235   absl::c_sort(instructions,
236                [&instructions_order](HloInstruction* a, HloInstruction* b) {
237                  return instructions_order.at(a) < instructions_order.at(b);
238                });
239   return instructions;
240 }
241 
242 }  // namespace xla
243