xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/cpu/compiler_functor.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"

#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Verifier.h"
#include "llvm/MC/MCContext.h"
#include "llvm/Object/ObjectFile.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/SmallVectorMemoryBuffer.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/AlwaysInliner.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include "llvm/Transforms/Instrumentation.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
46 
47 namespace xla {
48 namespace cpu {
49 
50 /* Create filtered versions of the LLVM Pass Managers to filter out some
51 of the expensive passes.
52 Profiling:
53    learning/brain/google/xla/benchmarks:inception_cpu_benchmark
54    learning/brain/google/xla/benchmarks:cifarnet
55 pointed to LICM and IndVarSimplify as the hottest passes.
56 LICM is known to exhibit O(n^2) time in the number of instructions.
57 IndVarSimplify is slow due to SCEV. If loops are emitted in canonical form,
58 this pass is not necessary.
59 Disabling these as a starting point.
60 */
61 // TODO(b/64227304) Creating a custom pass pipeline will replace this.
62 
63 namespace {
64 class FilteredPassManager : public llvm::legacy::PassManager {
65  public:
FilteredPassManager(bool disable_expensive_passes)66   explicit FilteredPassManager(bool disable_expensive_passes)
67       : disable_expensive_passes_(disable_expensive_passes) {}
add(llvm::Pass * p)68   void add(llvm::Pass* p) override {
69     // Disable all the loop unroll passes in the pipeline if
70     // `disable_expensive_passes_` is true (TODO: Maybe we should use
71     // `builder.DisableUnrollLoops` for this legacy feature?). Disable only the
72     // early loop full unroll pass, otherwise. The early loop full unroll pass
73     // applies excesive unrolling in statically bounded low trip-count loops,
74     // which are very common in XLA. It also creates a strong dependency on the
75     // SLP vectorizer to produce all the vector code, since the loops are fully
76     // unrolled. By disabling it, the Loop Vectorizer would have an opportunity
77     // to vectorize the code. A later loop unroll pass will still unroll the
78     // loops before SLP for those cases missed by the Loop Vectorizer.
79     constexpr unsigned loop_full_unroll_pos = 0;
80     if (p->getPassName().contains("Unroll loops") &&
81         (disable_expensive_passes_ ||
82          num_unroll_passes_ == loop_full_unroll_pos)) {
83       ++num_unroll_passes_;
84       delete p;
85       return;
86     }
87 
88     llvm::legacy::PassManager::add(p);
89   }
90 
91  private:
92   unsigned num_unroll_passes_ = 0;
93   bool disable_expensive_passes_;
94 };
95 }  // anonymous namespace
96 
operator ()(llvm::Module & module)97 llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> CompilerFunctor::operator()(
98     llvm::Module& module) {
99   FilteredPassManager module_passes(disable_expensive_passes_);
100   llvm::legacy::FunctionPassManager function_passes(&module);
101 
102   VLOG(2) << "IR before optimizations";
103   XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(module));
104 
105   if (pre_optimization_hook_) {
106     pre_optimization_hook_(module);
107   }
108 
109   if (dfsan_enabled_) {
110     module_passes.add(
111         llvm::createDataFlowSanitizerLegacyPassPass(dfsan_abi_list_files_));
112   }
113 
114   // Add the appropriate TargetLibraryInfo and TargetTransformInfo.
115   AddTargetInfoPasses(&module_passes);
116 
117   // Build up optimization pipeline.
118   if (optimize_for_size_) {
119     // Optimizing for size turns on -O2 level optimizations.
120     //
121     // TODO(b/64153864): Although the code generator supports size_level = 2 to
122     // turn on more aggressive code size optimizations than size_level = 1, we
123     // pass size_level = 1 because in many cases a size_level of 2 does
124     // worse. Investigate why.
125     AddOptimizationPasses(&module_passes, &function_passes, /*opt_level=*/2,
126                           /*size_level=*/1);
127   } else {
128     AddOptimizationPasses(&module_passes, &function_passes,
129                           /*opt_level=*/opt_level_, /*size_level=*/0);
130   }
131 
132   // Run optimization passes on module.
133   function_passes.doInitialization();
134 
135   CHECK(!llvm::verifyModule(module, &llvm::dbgs()));
136 
137   for (auto func = module.begin(); func != module.end(); ++func) {
138     function_passes.run(*func);
139   }
140   function_passes.doFinalization();
141   module_passes.run(module);
142 
143   CHECK(!llvm::verifyModule(module, &llvm::dbgs()));
144 
145   runtime::RewriteIRRuntimeFunctions(&module, fast_math_flags_);
146 
147   // Buffer for holding machine code prior to constructing the ObjectFile.
148   llvm::SmallVector<char, 0> stream_buffer;
149   llvm::raw_svector_ostream ostream(stream_buffer);
150 
151   VLOG(2) << "IR after optimizations";
152   XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(module));
153 
154   if (post_optimization_hook_) {
155     post_optimization_hook_(module);
156   }
157 
158   // Generate code.
159   llvm::MCContext* mc_context;
160   llvm::legacy::PassManager codegen_passes;
161   target_machine_->addPassesToEmitMC(codegen_passes, mc_context, ostream);
162   codegen_passes.run(module);
163 
164   std::unique_ptr<llvm::MemoryBuffer> memory_buffer(
165       new llvm::SmallVectorMemoryBuffer(std::move(stream_buffer)));
166 
167   if (post_codegen_hook_) {
168     llvm::Expected<std::unique_ptr<llvm::object::ObjectFile>> obj_file =
169         llvm::object::ObjectFile::createObjectFile(*memory_buffer);
170     if (obj_file) {
171       post_codegen_hook_(*obj_file.get());
172     } else {
173       LOG(WARNING) << "Could convert memory buffer to object file!";
174     }
175   }
176 
177   return std::move(memory_buffer);
178 }
179 
VectorFunctionsForTargetLibraryInfoImpl()180 static std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl() {
181   std::vector<llvm::VecDesc> result = {
182       {"tanhf", runtime::kTanhV4F32SymbolName, llvm::ElementCount::getFixed(4)},
183       {"llvm.tanh.f32", runtime::kTanhV4F32SymbolName,
184        llvm::ElementCount::getFixed(4)},
185 
186       {"tanhf", runtime::kTanhV8F32SymbolName, llvm::ElementCount::getFixed(8)},
187       {"llvm.tanh.f32", runtime::kTanhV8F32SymbolName,
188        llvm::ElementCount::getFixed(8)},
189 
190       {"tanhf", runtime::kTanhV16F32SymbolName,
191        llvm::ElementCount::getFixed(16)},
192       {"llvm.tanh.f32", runtime::kTanhV16F32SymbolName,
193        llvm::ElementCount::getFixed(16)},
194 
195       {"expf", runtime::kExpV4F32SymbolName, llvm::ElementCount::getFixed(4)},
196       {"llvm.exp.f32", runtime::kExpV4F32SymbolName,
197        llvm::ElementCount::getFixed(4)},
198 
199       {"expf", runtime::kExpV8F32SymbolName, llvm::ElementCount::getFixed(8)},
200       {"llvm.exp.f32", runtime::kExpV8F32SymbolName,
201        llvm::ElementCount::getFixed(8)},
202 
203       {"expf", runtime::kExpV16F32SymbolName, llvm::ElementCount::getFixed(16)},
204       {"llvm.exp.f32", runtime::kExpV16F32SymbolName,
205        llvm::ElementCount::getFixed(16)},
206 
207       {"logf", runtime::kLogV4F32SymbolName, llvm::ElementCount::getFixed(4)},
208       {"llvm.log.f32", runtime::kLogV4F32SymbolName,
209        llvm::ElementCount::getFixed(4)},
210 
211       {"logf", runtime::kLogV8F32SymbolName, llvm::ElementCount::getFixed(8)},
212       {"llvm.log.f32", runtime::kLogV8F32SymbolName,
213        llvm::ElementCount::getFixed(8)},
214 
215       {"logf", runtime::kLogV16F32SymbolName, llvm::ElementCount::getFixed(16)},
216       {"llvm.log.f32", runtime::kLogV16F32SymbolName,
217        llvm::ElementCount::getFixed(16)},
218   };
219   return result;
220 }
221 
AddTargetInfoPasses(llvm::legacy::PassManagerBase * passes) const222 void CompilerFunctor::AddTargetInfoPasses(
223     llvm::legacy::PassManagerBase* passes) const {
224   llvm::Triple target_triple(target_machine_->getTargetTriple());
225   auto target_library_info_impl =
226       std::make_unique<llvm::TargetLibraryInfoImpl>(target_triple);
227   target_library_info_impl->addVectorizableFunctions(
228       VectorFunctionsForTargetLibraryInfoImpl());
229 
230   passes->add(
231       new llvm::TargetLibraryInfoWrapperPass(*target_library_info_impl));
232   passes->add(createTargetTransformInfoWrapperPass(
233       target_machine_->getTargetIRAnalysis()));
234 }
235 
AddOptimizationPasses(llvm::legacy::PassManagerBase * module_passes,llvm::legacy::FunctionPassManager * function_passes,unsigned opt_level,unsigned size_level) const236 void CompilerFunctor::AddOptimizationPasses(
237     llvm::legacy::PassManagerBase* module_passes,
238     llvm::legacy::FunctionPassManager* function_passes, unsigned opt_level,
239     unsigned size_level) const {
240   llvm::PassManagerBuilder builder;
241   builder.OptLevel = opt_level;
242   builder.SizeLevel = size_level;
243 
244   if (opt_level > 1) {
245     builder.Inliner = llvm::createFunctionInliningPass();
246   } else {
247     // Only inline functions marked with "alwaysinline".
248     builder.Inliner = llvm::createAlwaysInlinerLegacyPass();
249   }
250 
251   builder.DisableUnrollLoops = opt_level == 0;
252   builder.LoopVectorize = opt_level > 0 && size_level == 0;
253   builder.SLPVectorize = opt_level > 1 && size_level == 0;
254 
255   builder.populateFunctionPassManager(*function_passes);
256   builder.populateModulePassManager(*module_passes);
257 }
258 
259 }  // namespace cpu
260 }  // namespace xla
261