1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
17
18 #include <algorithm>
19 #include <iterator>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24
25 #include "llvm/ADT/StringRef.h"
26 #include "llvm/Analysis/TargetLibraryInfo.h"
27 #include "llvm/Analysis/TargetTransformInfo.h"
28 #include "llvm/IR/LegacyPassManager.h"
29 #include "llvm/IR/Verifier.h"
30 #include "llvm/MC/MCContext.h"
31 #include "llvm/Object/ObjectFile.h"
32 #include "llvm/Support/SmallVectorMemoryBuffer.h"
33 #include "llvm/Support/raw_ostream.h"
34 #include "llvm/Target/TargetMachine.h"
35 #include "llvm/Transforms/IPO.h"
36 #include "llvm/Transforms/IPO/AlwaysInliner.h"
37 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
38 #include "llvm/Transforms/Instrumentation.h"
39 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
40 #include "tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h"
41 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
42 #include "tensorflow/compiler/xla/statusor.h"
43 #include "tensorflow/compiler/xla/types.h"
44 #include "tensorflow/compiler/xla/util.h"
45 #include "tensorflow/core/platform/logging.h"
46
47 namespace xla {
48 namespace cpu {
49
/* Create filtered versions of the LLVM Pass Managers to filter out some
of the expensive passes.
Profiling:
  learning/brain/google/xla/benchmarks:inception_cpu_benchmark
  learning/brain/google/xla/benchmarks:cifarnet
pointed to LICM and IndVarSimplify as the hottest passes.
LICM is known to exhibit O(n^2) time in the number of instructions.
IndVarSimplify is slow due to SCEV. If loops are emitted in canonical form,
this pass is not necessary.
These passes were disabled as a starting point; the filter below currently
removes only loop-unroll passes (see FilteredPassManager::add).
*/
// TODO(b/64227304) Creating a custom pass pipeline will replace this.
62
63 namespace {
64 class FilteredPassManager : public llvm::legacy::PassManager {
65 public:
FilteredPassManager(bool disable_expensive_passes)66 explicit FilteredPassManager(bool disable_expensive_passes)
67 : disable_expensive_passes_(disable_expensive_passes) {}
add(llvm::Pass * p)68 void add(llvm::Pass* p) override {
69 // Disable all the loop unroll passes in the pipeline if
70 // `disable_expensive_passes_` is true (TODO: Maybe we should use
71 // `builder.DisableUnrollLoops` for this legacy feature?). Disable only the
72 // early loop full unroll pass, otherwise. The early loop full unroll pass
73 // applies excesive unrolling in statically bounded low trip-count loops,
74 // which are very common in XLA. It also creates a strong dependency on the
75 // SLP vectorizer to produce all the vector code, since the loops are fully
76 // unrolled. By disabling it, the Loop Vectorizer would have an opportunity
77 // to vectorize the code. A later loop unroll pass will still unroll the
78 // loops before SLP for those cases missed by the Loop Vectorizer.
79 constexpr unsigned loop_full_unroll_pos = 0;
80 if (p->getPassName().contains("Unroll loops") &&
81 (disable_expensive_passes_ ||
82 num_unroll_passes_ == loop_full_unroll_pos)) {
83 ++num_unroll_passes_;
84 delete p;
85 return;
86 }
87
88 llvm::legacy::PassManager::add(p);
89 }
90
91 private:
92 unsigned num_unroll_passes_ = 0;
93 bool disable_expensive_passes_;
94 };
95 } // anonymous namespace
96
operator ()(llvm::Module & module)97 llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> CompilerFunctor::operator()(
98 llvm::Module& module) {
99 FilteredPassManager module_passes(disable_expensive_passes_);
100 llvm::legacy::FunctionPassManager function_passes(&module);
101
102 VLOG(2) << "IR before optimizations";
103 XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(module));
104
105 if (pre_optimization_hook_) {
106 pre_optimization_hook_(module);
107 }
108
109 if (dfsan_enabled_) {
110 module_passes.add(
111 llvm::createDataFlowSanitizerLegacyPassPass(dfsan_abi_list_files_));
112 }
113
114 // Add the appropriate TargetLibraryInfo and TargetTransformInfo.
115 AddTargetInfoPasses(&module_passes);
116
117 // Build up optimization pipeline.
118 if (optimize_for_size_) {
119 // Optimizing for size turns on -O2 level optimizations.
120 //
121 // TODO(b/64153864): Although the code generator supports size_level = 2 to
122 // turn on more aggressive code size optimizations than size_level = 1, we
123 // pass size_level = 1 because in many cases a size_level of 2 does
124 // worse. Investigate why.
125 AddOptimizationPasses(&module_passes, &function_passes, /*opt_level=*/2,
126 /*size_level=*/1);
127 } else {
128 AddOptimizationPasses(&module_passes, &function_passes,
129 /*opt_level=*/opt_level_, /*size_level=*/0);
130 }
131
132 // Run optimization passes on module.
133 function_passes.doInitialization();
134
135 CHECK(!llvm::verifyModule(module, &llvm::dbgs()));
136
137 for (auto func = module.begin(); func != module.end(); ++func) {
138 function_passes.run(*func);
139 }
140 function_passes.doFinalization();
141 module_passes.run(module);
142
143 CHECK(!llvm::verifyModule(module, &llvm::dbgs()));
144
145 runtime::RewriteIRRuntimeFunctions(&module, fast_math_flags_);
146
147 // Buffer for holding machine code prior to constructing the ObjectFile.
148 llvm::SmallVector<char, 0> stream_buffer;
149 llvm::raw_svector_ostream ostream(stream_buffer);
150
151 VLOG(2) << "IR after optimizations";
152 XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(module));
153
154 if (post_optimization_hook_) {
155 post_optimization_hook_(module);
156 }
157
158 // Generate code.
159 llvm::MCContext* mc_context;
160 llvm::legacy::PassManager codegen_passes;
161 target_machine_->addPassesToEmitMC(codegen_passes, mc_context, ostream);
162 codegen_passes.run(module);
163
164 std::unique_ptr<llvm::MemoryBuffer> memory_buffer(
165 new llvm::SmallVectorMemoryBuffer(std::move(stream_buffer)));
166
167 if (post_codegen_hook_) {
168 llvm::Expected<std::unique_ptr<llvm::object::ObjectFile>> obj_file =
169 llvm::object::ObjectFile::createObjectFile(*memory_buffer);
170 if (obj_file) {
171 post_codegen_hook_(*obj_file.get());
172 } else {
173 LOG(WARNING) << "Could convert memory buffer to object file!";
174 }
175 }
176
177 return std::move(memory_buffer);
178 }
179
VectorFunctionsForTargetLibraryInfoImpl()180 static std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl() {
181 std::vector<llvm::VecDesc> result = {
182 {"tanhf", runtime::kTanhV4F32SymbolName, llvm::ElementCount::getFixed(4)},
183 {"llvm.tanh.f32", runtime::kTanhV4F32SymbolName,
184 llvm::ElementCount::getFixed(4)},
185
186 {"tanhf", runtime::kTanhV8F32SymbolName, llvm::ElementCount::getFixed(8)},
187 {"llvm.tanh.f32", runtime::kTanhV8F32SymbolName,
188 llvm::ElementCount::getFixed(8)},
189
190 {"tanhf", runtime::kTanhV16F32SymbolName,
191 llvm::ElementCount::getFixed(16)},
192 {"llvm.tanh.f32", runtime::kTanhV16F32SymbolName,
193 llvm::ElementCount::getFixed(16)},
194
195 {"expf", runtime::kExpV4F32SymbolName, llvm::ElementCount::getFixed(4)},
196 {"llvm.exp.f32", runtime::kExpV4F32SymbolName,
197 llvm::ElementCount::getFixed(4)},
198
199 {"expf", runtime::kExpV8F32SymbolName, llvm::ElementCount::getFixed(8)},
200 {"llvm.exp.f32", runtime::kExpV8F32SymbolName,
201 llvm::ElementCount::getFixed(8)},
202
203 {"expf", runtime::kExpV16F32SymbolName, llvm::ElementCount::getFixed(16)},
204 {"llvm.exp.f32", runtime::kExpV16F32SymbolName,
205 llvm::ElementCount::getFixed(16)},
206
207 {"logf", runtime::kLogV4F32SymbolName, llvm::ElementCount::getFixed(4)},
208 {"llvm.log.f32", runtime::kLogV4F32SymbolName,
209 llvm::ElementCount::getFixed(4)},
210
211 {"logf", runtime::kLogV8F32SymbolName, llvm::ElementCount::getFixed(8)},
212 {"llvm.log.f32", runtime::kLogV8F32SymbolName,
213 llvm::ElementCount::getFixed(8)},
214
215 {"logf", runtime::kLogV16F32SymbolName, llvm::ElementCount::getFixed(16)},
216 {"llvm.log.f32", runtime::kLogV16F32SymbolName,
217 llvm::ElementCount::getFixed(16)},
218 };
219 return result;
220 }
221
AddTargetInfoPasses(llvm::legacy::PassManagerBase * passes) const222 void CompilerFunctor::AddTargetInfoPasses(
223 llvm::legacy::PassManagerBase* passes) const {
224 llvm::Triple target_triple(target_machine_->getTargetTriple());
225 auto target_library_info_impl =
226 std::make_unique<llvm::TargetLibraryInfoImpl>(target_triple);
227 target_library_info_impl->addVectorizableFunctions(
228 VectorFunctionsForTargetLibraryInfoImpl());
229
230 passes->add(
231 new llvm::TargetLibraryInfoWrapperPass(*target_library_info_impl));
232 passes->add(createTargetTransformInfoWrapperPass(
233 target_machine_->getTargetIRAnalysis()));
234 }
235
AddOptimizationPasses(llvm::legacy::PassManagerBase * module_passes,llvm::legacy::FunctionPassManager * function_passes,unsigned opt_level,unsigned size_level) const236 void CompilerFunctor::AddOptimizationPasses(
237 llvm::legacy::PassManagerBase* module_passes,
238 llvm::legacy::FunctionPassManager* function_passes, unsigned opt_level,
239 unsigned size_level) const {
240 llvm::PassManagerBuilder builder;
241 builder.OptLevel = opt_level;
242 builder.SizeLevel = size_level;
243
244 if (opt_level > 1) {
245 builder.Inliner = llvm::createFunctionInliningPass();
246 } else {
247 // Only inline functions marked with "alwaysinline".
248 builder.Inliner = llvm::createAlwaysInlinerLegacyPass();
249 }
250
251 builder.DisableUnrollLoops = opt_level == 0;
252 builder.LoopVectorize = opt_level > 0 && size_level == 0;
253 builder.SLPVectorize = opt_level > 1 && size_level == 0;
254
255 builder.populateFunctionPassManager(*function_passes);
256 builder.populateModulePassManager(*module_passes);
257 }
258
259 } // namespace cpu
260 } // namespace xla
261