1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
17
18 #include <functional>
19 #include <string>
20
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/strings/str_format.h"
24 #include "absl/strings/str_join.h"
25 #include "tensorflow/compiler/xla/service/dump.h"
26 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
27 #include "tensorflow/compiler/xla/service/hlo_proto_util.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 #include "tensorflow/compiler/xla/types.h"
30 #include "tensorflow/compiler/xla/util.h"
31 #include "tensorflow/core/platform/errors.h"
32 #include "tensorflow/core/platform/logging.h"
33
34 namespace xla {
35
36 namespace {
37
RecordPassStartMetadata(HloModule & module,const std::string & pass_name,const std::string & pipeline_name)38 void RecordPassStartMetadata(HloModule& module, const std::string& pass_name,
39 const std::string& pipeline_name) {
40 module.metadata()->RecordPassStart();
41 // An HloPassMetadata was just created so Status should always be OK.
42 TF_CHECK_OK(module.metadata()->set_current_pass_name(pass_name));
43 TF_CHECK_OK(module.metadata()->set_current_pass_pipeline_name(pipeline_name));
44 }
45
RecordPassStartMetadata(HloModuleGroup & module_group,const std::string & pass_name,const std::string & pipeline_name)46 void RecordPassStartMetadata(HloModuleGroup& module_group,
47 const std::string& pass_name,
48 const std::string& pipeline_name) {
49 for (HloModule* module : module_group.modules()) {
50 RecordPassStartMetadata(*module, pass_name, pipeline_name);
51 }
52 }
53
AttemptRecordPassEndMetadata(HloModule & module,const std::string & pass_name,bool module_changed)54 Status AttemptRecordPassEndMetadata(HloModule& module,
55 const std::string& pass_name,
56 bool module_changed) {
57 // Module id is set here instead of RecordPassStartMetadata because it may
58 // change in the middle of the pass, and we want the final id.
59 TF_RETURN_IF_ERROR(
60 module.metadata()->set_current_pass_module_id(module.unique_id()));
61 TF_RETURN_IF_ERROR(
62 module.metadata()->set_current_pass_module_changed(module_changed));
63 TF_RETURN_IF_ERROR(module.metadata()->RecordPassEnd());
64 return OkStatus();
65 }
66
RecordPassEndMetadata(HloModule & module,const std::string & pass_name,bool module_changed)67 void RecordPassEndMetadata(HloModule& module, const std::string& pass_name,
68 bool module_changed) {
69 Status status =
70 AttemptRecordPassEndMetadata(module, pass_name, module_changed);
71 if (!status.ok()) {
72 LOG(FATAL) << status;
73 }
74 }
75
AttemptRecordPassEndMetadata(HloModuleGroup & module_group,const std::string & pass_name,bool module_changed)76 Status AttemptRecordPassEndMetadata(HloModuleGroup& module_group,
77 const std::string& pass_name,
78 bool module_changed) {
79 for (HloModule* module : module_group.modules()) {
80 for (HloModule* other_module : module_group.modules()) {
81 TF_RETURN_IF_ERROR(
82 module->metadata()->add_current_pass_module_group_module_id(
83 other_module->unique_id()));
84 }
85 TF_RETURN_IF_ERROR(
86 AttemptRecordPassEndMetadata(*module, pass_name, module_changed));
87 }
88 return OkStatus();
89 }
90
RecordPassEndMetadata(HloModuleGroup & module_group,const std::string & pass_name,bool module_changed)91 void RecordPassEndMetadata(HloModuleGroup& module_group,
92 const std::string& pass_name, bool module_changed) {
93 Status status =
94 AttemptRecordPassEndMetadata(module_group, pass_name, module_changed);
95 if (!status.ok()) {
96 LOG(FATAL) << status;
97 }
98 }
99
SetInstructionMetadata(HloModule & module)100 void SetInstructionMetadata(HloModule& module) {
101 StatusOr<int64_t> pass_id = module.metadata()->current_pass_id();
102 if (!pass_id.ok()) {
103 LOG(FATAL) << pass_id.status();
104 }
105 for (xla::HloComputation* computation : module.computations()) {
106 for (xla::HloInstruction* instruction : computation->instructions()) {
107 if (instruction->metadata().creation_pass_id() == 0) {
108 instruction->set_creation_pass_id(*pass_id);
109 }
110 if (instruction->metadata().logical_creation_pass_id() == 0) {
111 instruction->set_logical_creation_pass_id(*pass_id);
112 }
113 }
114 }
115 }
116
SetInstructionMetadata(HloModuleGroup & module_group)117 void SetInstructionMetadata(HloModuleGroup& module_group) {
118 for (HloModule* module : module_group.modules()) {
119 SetInstructionMetadata(*module);
120 }
121 }
122
123 } // namespace
124
125 template <typename HloT>
RunInvariantCheckers(HloT * hlo,absl::string_view after_pass_name,const absl::flat_hash_set<absl::string_view> & execution_threads)126 Status HloPassPipeline::RunInvariantCheckers(
127 HloT* hlo, absl::string_view after_pass_name,
128 const absl::flat_hash_set<absl::string_view>& execution_threads) {
129 for (auto& invariant_checker : invariant_checkers_) {
130 VLOG(1) << " Invariant checker " << invariant_checker->name();
131 StatusOr<bool> changed_status =
132 RunHelper(invariant_checker.get(), hlo, execution_threads);
133 VLOG(1) << " Invariant checker done " << invariant_checker->name();
134 if (!changed_status.ok()) {
135 VLOG(2) << "Failed invariant check:";
136 XLA_VLOG_LINES(2, hlo->ToString());
137 return tensorflow::errors::CreateWithUpdatedMessage(
138 changed_status.status(),
139 absl::StrCat(changed_status.status().error_message(),
140 "\n\nFailed after ", after_pass_name));
141 }
142 TF_RET_CHECK(!changed_status.ValueOrDie())
143 << "invariant checkers must not change the graph";
144 }
145 return OkStatus();
146 }
147
148 template <typename HloT>
RunPassesInternal(HloT * hlo,const DebugOptions & debug_options,const absl::flat_hash_set<absl::string_view> & execution_threads)149 StatusOr<bool> HloPassPipeline::RunPassesInternal(
150 HloT* hlo, const DebugOptions& debug_options,
151 const absl::flat_hash_set<absl::string_view>& execution_threads) {
152 auto passes = GetEnabledPasses(debug_options);
153 // Copy string by value since debug options could get clobbered in an hlo
154 // module group pass.
155 std::string dump_regex = debug_options.xla_dump_hlo_pass_re();
156 static constexpr absl::string_view kPipelineStart = "pipeline-start";
157 static constexpr absl::string_view kPipelineEnd = "pipeline-end";
158 std::string pipeline_name = std::string(name());
159
160 TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, kPipelineStart));
161
162 RecordPassStartMetadata(*hlo, std::string(kPipelineStart), pipeline_name);
163 SetInstructionMetadata(*hlo);
164 MaybeDumpHloAndSaveFilenames(*hlo,
165 /*after_pass_name=*/kPipelineStart,
166 /*before_pass_name=*/passes.empty()
167 ? kPipelineEnd
168 : passes.front()->name());
169 RecordPassEndMetadata(*hlo, std::string(kPipelineStart),
170 /*module_changed=*/false);
171
172 bool changed = false;
173 for (int i = 0; i < passes.size(); i++) {
174 HloPassInterface* pass = passes[i];
175 XLA_SCOPED_LOGGING_TIMER(absl::StrCat("HLO pass: ", pass->name()));
176 std::string pass_name = std::string(pass->name());
177 VLOG(1) << " HLO pass " << pass_name;
178 VLOG(2) << " Module hash " << absl::HashOf(*hlo);
179 if (!pass->IsPassPipeline()) {
180 compilation_stats_->StartPass(pass_name);
181 }
182 RecordPassStartMetadata(*hlo, pass_name, pipeline_name);
183 TF_ASSIGN_OR_RETURN(bool pass_changed,
184 RunHelper(pass, hlo, execution_threads));
185 SetInstructionMetadata(*hlo);
186 if (!dump_regex.empty() && (pass_changed || dump_regex != ".*")) {
187 MaybeDumpHloAndSaveFilenames(*hlo,
188 /*after_pass_name=*/pass_name,
189 /*before_pass_name=*/i + 1 >= passes.size()
190 ? kPipelineEnd
191 : passes[i + 1]->name());
192 }
193 RecordPassEndMetadata(*hlo, pass_name, pass_changed);
194 changed |= pass_changed;
195 if (pass_changed) {
196 VLOG(3) << " Pass caused changes " << pass->name();
197 }
198 TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass_name));
199 if (!pass->IsPassPipeline()) {
200 compilation_stats_->EndPass(pass_name);
201 }
202 }
203 return changed;
204 }
205
GetEnabledPasses(const DebugOptions & debug_options)206 std::vector<HloPassInterface*> HloPassPipeline::GetEnabledPasses(
207 const DebugOptions& debug_options) {
208 if (debug_options.xla_disable_all_hlo_passes()) {
209 VLOG(1) << "*All* passes disabled by --xla_disable_all_hlo_passes.";
210 return {};
211 }
212
213 absl::flat_hash_set<std::string> disabled_pass_names(
214 debug_options.xla_disable_hlo_passes().begin(),
215 debug_options.xla_disable_hlo_passes().end());
216
217 absl::flat_hash_set<std::string> enabled_pass_names(
218 debug_options.xla_enable_hlo_passes_only().begin(),
219 debug_options.xla_enable_hlo_passes_only().end());
220
221 if (!disabled_pass_names.empty()) {
222 VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
223 << absl::StrJoin(disabled_pass_names, ", ");
224 }
225
226 if (!enabled_pass_names.empty()) {
227 VLOG(1) << "Passes enabled by --xla_enable_hlo_passes_only: "
228 << absl::StrJoin(enabled_pass_names, ", ");
229 }
230
231 CHECK(disabled_pass_names.empty() || enabled_pass_names.empty());
232
233 std::vector<HloPassInterface*> enabled_passes;
234 if (!enabled_pass_names.empty()) {
235 for (auto& pass : passes_) {
236 if (enabled_pass_names.contains(pass->name())) {
237 enabled_passes.push_back(pass.get());
238 }
239 }
240 } else {
241 for (auto& pass : passes_) {
242 if (!disabled_pass_names.contains(pass->name())) {
243 enabled_passes.push_back(pass.get());
244 }
245 }
246 }
247 return enabled_passes;
248 }
249
MaybeDumpHloAndSaveFilenames(HloModule & module,absl::string_view after_pass_name,absl::string_view before_pass_name)250 void HloPassPipeline::MaybeDumpHloAndSaveFilenames(
251 HloModule& module, absl::string_view after_pass_name,
252 absl::string_view before_pass_name) {
253 for (const std::string& filename : DumpHloModuleBetweenPassesIfEnabled(
254 name(), before_pass_name, after_pass_name, module)) {
255 Status status = module.metadata()->add_current_pass_dump_filename(filename);
256 if (!status.ok()) {
257 LOG(FATAL) << status;
258 }
259 }
260 }
261
MaybeDumpHloAndSaveFilenames(HloModuleGroup & module_group,absl::string_view after_pass_name,absl::string_view before_pass_name)262 void HloPassPipeline::MaybeDumpHloAndSaveFilenames(
263 HloModuleGroup& module_group, absl::string_view after_pass_name,
264 absl::string_view before_pass_name) {
265 for (HloModule* module : module_group.modules()) {
266 MaybeDumpHloAndSaveFilenames(*module, after_pass_name, before_pass_name);
267 }
268 }
269
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)270 StatusOr<bool> HloPassPipeline::Run(
271 HloModule* module,
272 const absl::flat_hash_set<absl::string_view>& execution_threads) {
273 run_called_ = true;
274
275 VLOG(1) << "Running HLO pass pipeline on module " << module->name() << ": "
276 << name();
277
278 return RunPassesInternal(module, module->config().debug_options(),
279 execution_threads);
280 }
281
RunOnModuleGroup(HloModuleGroup * module_group,const absl::flat_hash_set<absl::string_view> & execution_threads)282 StatusOr<bool> HloPassPipeline::RunOnModuleGroup(
283 HloModuleGroup* module_group,
284 const absl::flat_hash_set<absl::string_view>& execution_threads) {
285 run_called_ = true;
286
287 VLOG(1) << "Running HLO pass pipeline on module group "
288 << module_group->name() << ": " << name();
289
290 if (module_group->modules().empty()) {
291 VLOG(1) << "Module group is empty. Nothing to do.";
292 return false;
293 }
294
295 return RunPassesInternal(module_group,
296 module_group->module(0).config().debug_options(),
297 execution_threads);
298 }
299
300 } // namespace xla
301