xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
17 
18 #include <functional>
19 #include <string>
20 
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/strings/str_format.h"
24 #include "absl/strings/str_join.h"
25 #include "tensorflow/compiler/xla/service/dump.h"
26 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
27 #include "tensorflow/compiler/xla/service/hlo_proto_util.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 #include "tensorflow/compiler/xla/types.h"
30 #include "tensorflow/compiler/xla/util.h"
31 #include "tensorflow/core/platform/errors.h"
32 #include "tensorflow/core/platform/logging.h"
33 
34 namespace xla {
35 
36 namespace {
37 
RecordPassStartMetadata(HloModule & module,const std::string & pass_name,const std::string & pipeline_name)38 void RecordPassStartMetadata(HloModule& module, const std::string& pass_name,
39                              const std::string& pipeline_name) {
40   module.metadata()->RecordPassStart();
41   // An HloPassMetadata was just created so Status should always be OK.
42   TF_CHECK_OK(module.metadata()->set_current_pass_name(pass_name));
43   TF_CHECK_OK(module.metadata()->set_current_pass_pipeline_name(pipeline_name));
44 }
45 
RecordPassStartMetadata(HloModuleGroup & module_group,const std::string & pass_name,const std::string & pipeline_name)46 void RecordPassStartMetadata(HloModuleGroup& module_group,
47                              const std::string& pass_name,
48                              const std::string& pipeline_name) {
49   for (HloModule* module : module_group.modules()) {
50     RecordPassStartMetadata(*module, pass_name, pipeline_name);
51   }
52 }
53 
AttemptRecordPassEndMetadata(HloModule & module,const std::string & pass_name,bool module_changed)54 Status AttemptRecordPassEndMetadata(HloModule& module,
55                                     const std::string& pass_name,
56                                     bool module_changed) {
57   // Module id is set here instead of RecordPassStartMetadata because it may
58   // change in the middle of the pass, and we want the final id.
59   TF_RETURN_IF_ERROR(
60       module.metadata()->set_current_pass_module_id(module.unique_id()));
61   TF_RETURN_IF_ERROR(
62       module.metadata()->set_current_pass_module_changed(module_changed));
63   TF_RETURN_IF_ERROR(module.metadata()->RecordPassEnd());
64   return OkStatus();
65 }
66 
RecordPassEndMetadata(HloModule & module,const std::string & pass_name,bool module_changed)67 void RecordPassEndMetadata(HloModule& module, const std::string& pass_name,
68                            bool module_changed) {
69   Status status =
70       AttemptRecordPassEndMetadata(module, pass_name, module_changed);
71   if (!status.ok()) {
72     LOG(FATAL) << status;
73   }
74 }
75 
AttemptRecordPassEndMetadata(HloModuleGroup & module_group,const std::string & pass_name,bool module_changed)76 Status AttemptRecordPassEndMetadata(HloModuleGroup& module_group,
77                                     const std::string& pass_name,
78                                     bool module_changed) {
79   for (HloModule* module : module_group.modules()) {
80     for (HloModule* other_module : module_group.modules()) {
81       TF_RETURN_IF_ERROR(
82           module->metadata()->add_current_pass_module_group_module_id(
83               other_module->unique_id()));
84     }
85     TF_RETURN_IF_ERROR(
86         AttemptRecordPassEndMetadata(*module, pass_name, module_changed));
87   }
88   return OkStatus();
89 }
90 
RecordPassEndMetadata(HloModuleGroup & module_group,const std::string & pass_name,bool module_changed)91 void RecordPassEndMetadata(HloModuleGroup& module_group,
92                            const std::string& pass_name, bool module_changed) {
93   Status status =
94       AttemptRecordPassEndMetadata(module_group, pass_name, module_changed);
95   if (!status.ok()) {
96     LOG(FATAL) << status;
97   }
98 }
99 
SetInstructionMetadata(HloModule & module)100 void SetInstructionMetadata(HloModule& module) {
101   StatusOr<int64_t> pass_id = module.metadata()->current_pass_id();
102   if (!pass_id.ok()) {
103     LOG(FATAL) << pass_id.status();
104   }
105   for (xla::HloComputation* computation : module.computations()) {
106     for (xla::HloInstruction* instruction : computation->instructions()) {
107       if (instruction->metadata().creation_pass_id() == 0) {
108         instruction->set_creation_pass_id(*pass_id);
109       }
110       if (instruction->metadata().logical_creation_pass_id() == 0) {
111         instruction->set_logical_creation_pass_id(*pass_id);
112       }
113     }
114   }
115 }
116 
SetInstructionMetadata(HloModuleGroup & module_group)117 void SetInstructionMetadata(HloModuleGroup& module_group) {
118   for (HloModule* module : module_group.modules()) {
119     SetInstructionMetadata(*module);
120   }
121 }
122 
123 }  // namespace
124 
125 template <typename HloT>
RunInvariantCheckers(HloT * hlo,absl::string_view after_pass_name,const absl::flat_hash_set<absl::string_view> & execution_threads)126 Status HloPassPipeline::RunInvariantCheckers(
127     HloT* hlo, absl::string_view after_pass_name,
128     const absl::flat_hash_set<absl::string_view>& execution_threads) {
129   for (auto& invariant_checker : invariant_checkers_) {
130     VLOG(1) << "    Invariant checker " << invariant_checker->name();
131     StatusOr<bool> changed_status =
132         RunHelper(invariant_checker.get(), hlo, execution_threads);
133     VLOG(1) << "    Invariant checker done " << invariant_checker->name();
134     if (!changed_status.ok()) {
135       VLOG(2) << "Failed invariant check:";
136       XLA_VLOG_LINES(2, hlo->ToString());
137       return tensorflow::errors::CreateWithUpdatedMessage(
138           changed_status.status(),
139           absl::StrCat(changed_status.status().error_message(),
140                        "\n\nFailed after ", after_pass_name));
141     }
142     TF_RET_CHECK(!changed_status.ValueOrDie())
143         << "invariant checkers must not change the graph";
144   }
145   return OkStatus();
146 }
147 
148 template <typename HloT>
RunPassesInternal(HloT * hlo,const DebugOptions & debug_options,const absl::flat_hash_set<absl::string_view> & execution_threads)149 StatusOr<bool> HloPassPipeline::RunPassesInternal(
150     HloT* hlo, const DebugOptions& debug_options,
151     const absl::flat_hash_set<absl::string_view>& execution_threads) {
152   auto passes = GetEnabledPasses(debug_options);
153   // Copy string by value since debug options could get clobbered in an hlo
154   // module group pass.
155   std::string dump_regex = debug_options.xla_dump_hlo_pass_re();
156   static constexpr absl::string_view kPipelineStart = "pipeline-start";
157   static constexpr absl::string_view kPipelineEnd = "pipeline-end";
158   std::string pipeline_name = std::string(name());
159 
160   TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, kPipelineStart));
161 
162   RecordPassStartMetadata(*hlo, std::string(kPipelineStart), pipeline_name);
163   SetInstructionMetadata(*hlo);
164   MaybeDumpHloAndSaveFilenames(*hlo,
165                                /*after_pass_name=*/kPipelineStart,
166                                /*before_pass_name=*/passes.empty()
167                                    ? kPipelineEnd
168                                    : passes.front()->name());
169   RecordPassEndMetadata(*hlo, std::string(kPipelineStart),
170                         /*module_changed=*/false);
171 
172   bool changed = false;
173   for (int i = 0; i < passes.size(); i++) {
174     HloPassInterface* pass = passes[i];
175     XLA_SCOPED_LOGGING_TIMER(absl::StrCat("HLO pass: ", pass->name()));
176     std::string pass_name = std::string(pass->name());
177     VLOG(1) << "  HLO pass " << pass_name;
178     VLOG(2) << "  Module hash " << absl::HashOf(*hlo);
179     if (!pass->IsPassPipeline()) {
180       compilation_stats_->StartPass(pass_name);
181     }
182     RecordPassStartMetadata(*hlo, pass_name, pipeline_name);
183     TF_ASSIGN_OR_RETURN(bool pass_changed,
184                         RunHelper(pass, hlo, execution_threads));
185     SetInstructionMetadata(*hlo);
186     if (!dump_regex.empty() && (pass_changed || dump_regex != ".*")) {
187       MaybeDumpHloAndSaveFilenames(*hlo,
188                                    /*after_pass_name=*/pass_name,
189                                    /*before_pass_name=*/i + 1 >= passes.size()
190                                        ? kPipelineEnd
191                                        : passes[i + 1]->name());
192     }
193     RecordPassEndMetadata(*hlo, pass_name, pass_changed);
194     changed |= pass_changed;
195     if (pass_changed) {
196       VLOG(3) << "  Pass caused changes " << pass->name();
197     }
198     TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass_name));
199     if (!pass->IsPassPipeline()) {
200       compilation_stats_->EndPass(pass_name);
201     }
202   }
203   return changed;
204 }
205 
GetEnabledPasses(const DebugOptions & debug_options)206 std::vector<HloPassInterface*> HloPassPipeline::GetEnabledPasses(
207     const DebugOptions& debug_options) {
208   if (debug_options.xla_disable_all_hlo_passes()) {
209     VLOG(1) << "*All* passes disabled by --xla_disable_all_hlo_passes.";
210     return {};
211   }
212 
213   absl::flat_hash_set<std::string> disabled_pass_names(
214       debug_options.xla_disable_hlo_passes().begin(),
215       debug_options.xla_disable_hlo_passes().end());
216 
217   absl::flat_hash_set<std::string> enabled_pass_names(
218       debug_options.xla_enable_hlo_passes_only().begin(),
219       debug_options.xla_enable_hlo_passes_only().end());
220 
221   if (!disabled_pass_names.empty()) {
222     VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
223             << absl::StrJoin(disabled_pass_names, ", ");
224   }
225 
226   if (!enabled_pass_names.empty()) {
227     VLOG(1) << "Passes enabled by --xla_enable_hlo_passes_only: "
228             << absl::StrJoin(enabled_pass_names, ", ");
229   }
230 
231   CHECK(disabled_pass_names.empty() || enabled_pass_names.empty());
232 
233   std::vector<HloPassInterface*> enabled_passes;
234   if (!enabled_pass_names.empty()) {
235     for (auto& pass : passes_) {
236       if (enabled_pass_names.contains(pass->name())) {
237         enabled_passes.push_back(pass.get());
238       }
239     }
240   } else {
241     for (auto& pass : passes_) {
242       if (!disabled_pass_names.contains(pass->name())) {
243         enabled_passes.push_back(pass.get());
244       }
245     }
246   }
247   return enabled_passes;
248 }
249 
MaybeDumpHloAndSaveFilenames(HloModule & module,absl::string_view after_pass_name,absl::string_view before_pass_name)250 void HloPassPipeline::MaybeDumpHloAndSaveFilenames(
251     HloModule& module, absl::string_view after_pass_name,
252     absl::string_view before_pass_name) {
253   for (const std::string& filename : DumpHloModuleBetweenPassesIfEnabled(
254            name(), before_pass_name, after_pass_name, module)) {
255     Status status = module.metadata()->add_current_pass_dump_filename(filename);
256     if (!status.ok()) {
257       LOG(FATAL) << status;
258     }
259   }
260 }
261 
MaybeDumpHloAndSaveFilenames(HloModuleGroup & module_group,absl::string_view after_pass_name,absl::string_view before_pass_name)262 void HloPassPipeline::MaybeDumpHloAndSaveFilenames(
263     HloModuleGroup& module_group, absl::string_view after_pass_name,
264     absl::string_view before_pass_name) {
265   for (HloModule* module : module_group.modules()) {
266     MaybeDumpHloAndSaveFilenames(*module, after_pass_name, before_pass_name);
267   }
268 }
269 
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)270 StatusOr<bool> HloPassPipeline::Run(
271     HloModule* module,
272     const absl::flat_hash_set<absl::string_view>& execution_threads) {
273   run_called_ = true;
274 
275   VLOG(1) << "Running HLO pass pipeline on module " << module->name() << ": "
276           << name();
277 
278   return RunPassesInternal(module, module->config().debug_options(),
279                            execution_threads);
280 }
281 
RunOnModuleGroup(HloModuleGroup * module_group,const absl::flat_hash_set<absl::string_view> & execution_threads)282 StatusOr<bool> HloPassPipeline::RunOnModuleGroup(
283     HloModuleGroup* module_group,
284     const absl::flat_hash_set<absl::string_view>& execution_threads) {
285   run_called_ = true;
286 
287   VLOG(1) << "Running HLO pass pipeline on module group "
288           << module_group->name() << ": " << name();
289 
290   if (module_group->modules().empty()) {
291     VLOG(1) << "Module group is empty. Nothing to do.";
292     return false;
293   }
294 
295   return RunPassesInternal(module_group,
296                            module_group->module(0).config().debug_options(),
297                            execution_threads);
298 }
299 
300 }  // namespace xla
301