xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_domain_map.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_MAP_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_MAP_H_
18 
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/status.h"
30 
31 namespace xla {
32 
// The HloDomainMap splits a set of instructions within a module or computation
// into different domains, separated by kDomain instructions.
// A domain is composed of a set of instructions which can reach each other via
// operand/user edges, without crossing a kDomain instruction of a given kind.
// A domain never crosses computation boundaries.
38 class HloDomainMap {
39  public:
40   // Creates a new HloDomainMap, creating all the domains within the input
41   // computation, of the given kind. If domain_kind is not empty, only the
42   // kDomain instructions of domain_kind will be considered as separators.
43   // Otherwise every kDomain instruction will be splitting domains.
44   static StatusOr<std::unique_ptr<HloDomainMap>> Create(
45       HloComputation* computation, std::string domain_kind);
46 
47   // Creates a new HloDomainMap, creating all the domains within the input
48   // module, of the given kind. If domain_kind is not empty, only the
49   // kDomain instructions of domain_kind will be considered as separators.
50   // Otherwise every kDomain instruction will be splitting domains.
51   static StatusOr<std::unique_ptr<HloDomainMap>> Create(
52       HloModule* module, std::string domain_kind);
53 
54   // Retrieves all the domains the input module or computation are composed by.
GetDomains()55   const std::vector<std::unique_ptr<DomainMetadata::Domain>>& GetDomains()
56       const {
57     return instruction_domains_;
58   }
59 
60   // Checks whether two instructions are within the same domain.
61   bool InSameDomain(const HloInstruction* instruction1,
62                     const HloInstruction* instruction2) const;
63 
64   // Checks whether instruction is a kDomain instruction of the kind we are
65   // currently processing.
66   bool IsDomainInstruction(const HloInstruction* instruction) const;
67 
68   // Retrieves the domain identifier of the instruction, or -1 in case
69   // instruction is not found within any domain.
70   int64_t GetDomainId(const HloInstruction* instruction) const;
71 
72   // Returns the unique id of the domain metadata for the domain the given
73   // instruction belongs to. The given instruction must not be a kDomain
74   // instruction since each domain instruction is associated with 2 domains.
75   int64_t GetDomainMetadataId(const HloInstruction* instruction) const;
76 
77  private:
78   // Map used for representing instruction ordering, i.e.
79   // order_map[a] < order_map[b] means a must be ordered before b.
80   using InstructionOrderMap =
81       absl::flat_hash_map<const HloInstruction*, int64_t>;
82 
HloDomainMap(std::string domain_kind)83   HloDomainMap(std::string domain_kind)
84       : domain_kind_(std::move(domain_kind)) {}
85 
86   // Check if the kDomain instruction is facing (via its operand link) another
87   // kDomain instruction of the same kind, hence defining an empty domain.
88   // If that is the case, create the empty domain and call the proper
89   // normalizer.
90   Status TryProcessEmptyDomain(HloInstruction* instruction);
91 
92   Status Populate(HloComputation* computation);
93 
94   // Inserts the provided domain into the ones tracked by this object,
95   // creating a new domain ID.
96   Status InsertDomain(std::unique_ptr<DomainMetadata::Domain> domain);
97 
98   // From the given instruction, expands operand and user wise, the set of
99   // instructions which can be reached without crossing a kDomain instruction
100   // of the kind specified by domain_kind_.
101   // The domain data structure will be populated with all the reached
102   // instructions, and the boundaries of the domain, with the kDomain
103   // instructions encountered while expanding the reach.
104   Status ExpandDomain(HloInstruction* instruction,
105                       DomainMetadata::Domain* domain) const;
106 
107   // Creates a domain data structure using the ExpandDomain() API.
108   StatusOr<std::unique_ptr<DomainMetadata::Domain>> CreateDomain(
109       HloInstruction* instruction,
110       const InstructionOrderMap& instructions_order) const;
111 
112   // Out of an instruction set, returns a vector of all the ones which are not
113   // a kDomain kind.
114   static std::vector<HloInstruction*> MakeNonDomainInstructions(
115       const absl::flat_hash_set<HloInstruction*>& instruction_set,
116       const InstructionOrderMap& instructions_order);
117 
118   // Populates domain_metadata_id_ that maps each HloInstruction to the unique
119   // ID of its associated domain metatadata.
120   Status PopulateDomainMetadataMap();
121 
122   std::string domain_kind_;
123   std::vector<std::unique_ptr<DomainMetadata::Domain>> instruction_domains_;
124   absl::flat_hash_map<const HloInstruction*, int64_t> instruction_to_domain_;
125   absl::flat_hash_map<const HloInstruction*, int64_t> domain_metadata_id_;
126 };
127 
128 }  // namespace xla
129 
130 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_MAP_H_
131