xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/tools/hlo_extractor.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/tools/hlo_extractor.h"
17 
18 #include <stdio.h>
19 #include <unistd.h>
20 
21 #include <memory>
22 
23 #include "absl/container/flat_hash_map.h"
24 #include "absl/container/flat_hash_set.h"
25 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
26 #include "tensorflow/compiler/xla/service/hlo_clone_context.h"
27 #include "tensorflow/compiler/xla/service/hlo_computation.h"
28 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
29 #include "tensorflow/compiler/xla/status.h"
30 
31 namespace xla {
32 namespace {
33 
34 // Visitor that build a new HLO module with an entry computation and a root that
35 // is provided to the visit function. Only HLOs that are reachable from the new
36 // root instruction are included in the new module.
37 //
38 // The constructor allows specifying a set of boundary HLOs to prune the HLO
39 // graph. HLOs at the boundary are replaced with parameters. Can be nullptr
40 // which means no boundary, i.e. no HLOs are replaced with parameters.
41 class ExtractionVisitor : public ConstDfsHloVisitorWithDefault {
42  public:
ExtractionVisitor(const HloModule & old_module,absl::flat_hash_set<const HloInstruction * > * boundary)43   explicit ExtractionVisitor(
44       const HloModule& old_module,
45       absl::flat_hash_set<const HloInstruction*>* boundary)
46       : old_module_(old_module),
47         module_(std::make_unique<HloModule>("extracted", config_)),
48         clone_context_(module_.get()),
49         builder_("entry_computation"),
50         boundary_(boundary) {}
51 
HandleParameter(const HloInstruction * parameter)52   Status HandleParameter(const HloInstruction* parameter) override {
53     // Entry parameters need renumbering.
54     auto new_parameter = HloInstruction::CreateParameter(
55         parameter_number_++, parameter->shape(), parameter->name());
56     clone_context_.MapInstruction(parameter, new_parameter.get());
57     builder_.AddInstruction(std::move(new_parameter));
58     return OkStatus();
59   }
60 
DefaultAction(const HloInstruction * hlo)61   Status DefaultAction(const HloInstruction* hlo) override {
62     // Replace instructions at the boundary with parameters, but leave constants
63     // untouched.
64     if (boundary_ != nullptr && boundary_->count(hlo) > 0) {
65       auto new_parameter = HloInstruction::CreateParameter(
66           parameter_number_, hlo->shape(), hlo->name());
67       parameter_number_++;
68       clone_context_.MapInstruction(hlo, new_parameter.get());
69       builder_.AddInstruction(std::move(new_parameter));
70       return OkStatus();
71     }
72     std::vector<HloInstruction*> new_operands;
73     for (auto operand : hlo->operands()) {
74       new_operands.push_back(clone_context_.GetInstruction(operand));
75     }
76     auto instruction =
77         hlo->CloneWithNewOperands(hlo->shape(), new_operands, &clone_context_);
78     builder_.AddInstruction(std::move(instruction));
79     return OkStatus();
80   }
81 
FinishVisit(const HloInstruction *)82   Status FinishVisit(const HloInstruction* /*root*/) override {
83     module_->AddEntryComputation(builder_.Build());
84     // Rename HLOs so that their name matches the original. By default,
85     // HLOs get new unique names when adding a new entry computation to
86     // a module.
87     for (auto computation : old_module_.MakeComputationPostOrder()) {
88       for (auto old_instruction : computation->MakeInstructionPostOrder()) {
89         if (auto new_instruction =
90                 clone_context_.FindInstruction(old_instruction)) {
91           new_instruction->SetAndSanitizeName(old_instruction->name());
92         }
93       }
94     }
95     return OkStatus();
96   }
97 
module()98   HloModule* module() { return module_.get(); }
99 
ConsumeModule()100   std::unique_ptr<HloModule> ConsumeModule() { return std::move(module_); }
101 
102  private:
103   const HloModule& old_module_;
104   HloModuleConfig config_;
105   std::unique_ptr<HloModule> module_;
106   HloCloneContext clone_context_;
107   HloComputation::Builder builder_;
108   absl::flat_hash_set<const HloInstruction*>* boundary_;
109   int64_t parameter_number_ = 0;
110 };
111 
ComputeBoundary(const HloInstruction * root,int64_t limit,absl::flat_hash_set<const HloInstruction * > * boundary)112 void ComputeBoundary(const HloInstruction* root, int64_t limit,
113                      absl::flat_hash_set<const HloInstruction*>* boundary) {
114   std::deque<const HloInstruction*> worklist;
115   absl::flat_hash_map<const HloInstruction*, int64_t> visited;
116   worklist.push_back(root);
117   visited.emplace(root, 0);
118   while (!worklist.empty()) {
119     auto hlo = worklist.front();
120     worklist.pop_front();
121     int64_t hops = visited[hlo];
122     if (hops > limit) {
123       boundary->insert(hlo);
124       continue;
125     }
126     for (const HloInstruction* operand : hlo->operands()) {
127       if (visited.count(operand)) {
128         continue;
129       }
130       worklist.push_back(operand);
131       visited.emplace(operand, hops + 1);
132     }
133   }
134 }
135 
136 }  // namespace
137 
ExtractModule(HloInstruction * instruction,int64_t height)138 std::unique_ptr<HloModule> ExtractModule(HloInstruction* instruction,
139                                          int64_t height) {
140   absl::flat_hash_set<const HloInstruction*> boundary;
141   if (height != -1) {
142     ComputeBoundary(instruction, height, &boundary);
143   }
144   ExtractionVisitor visitor(*instruction->GetModule(), &boundary);
145   CHECK(instruction->Accept(&visitor).ok());
146 
147   // The first pass may leave unused parameter instructions. Do another
148   // extraction pass to remove unused parameters. This is done because
149   // HloComputation does not allow removing parameters after the computation has
150   // been built.
151   ExtractionVisitor cleanup_visitor(*visitor.module(), /*boundary=*/nullptr);
152   TF_CHECK_OK(visitor.module()->entry_computation()->root_instruction()->Accept(
153       &cleanup_visitor));
154 
155   HloVerifier verifier(/*layout_sensitive=*/false,
156                        /*allow_mixed_precision=*/true);
157   TF_CHECK_OK(verifier.Run(cleanup_visitor.module()).status());
158   return cleanup_visitor.ConsumeModule();
159 }
160 
161 }  // namespace xla
162