xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/codegen/fuser/cpu/fused_kernel.h>
2 
3 #include <ATen/DynamicLibrary.h>
4 #include <ATen/code_template.h>
5 #include <c10/util/Exception.h>
6 #include <torch/csrc/jit/codegen/fuser/compiler.h>
7 #include <torch/csrc/jit/codegen/fuser/cpu/temp_file.h>
8 #include <optional>
9 
10 #include <cstdlib>
11 #include <iostream>
12 #include <string>
13 
14 namespace torch {
15 namespace jit {
16 namespace fuser {
17 namespace cpu {
18 
#ifdef _MSC_VER
// Returns the current user's temporary directory as reported by
// GetTempPathW, converted to UTF-8. Throws (via TORCH_CHECK) if the call
// fails or the path does not fit into a MAX_PATH-sized buffer.
static std::string getTempPath() {
  wchar_t lpTempPathBuffer[MAX_PATH];

  DWORD dwRetVal = GetTempPathW(
      MAX_PATH, // length of the buffer
      lpTempPathBuffer); // buffer for path

  TORCH_CHECK(dwRetVal < MAX_PATH && dwRetVal != 0, "GetTempPath failed.");

  // c10::u16u8 already yields a std::string; returning it directly (and not
  // as `const std::string`) lets the result be moved into the caller.
  return c10::u16u8(lpTempPathBuffer);
}
// File names are appended directly to temp_dir below, which relies on the
// path from GetTempPathW ending with a path separator.
static const std::string temp_dir = getTempPath();
// Templates handed to TempFile together with the length of the extension
// that follows the "XXXXXX" placeholder (".dll" / ".cpp").
static const std::string so_template = temp_dir + "pytorch_fuserXXXXXX.dll";
static const std::string cpp_template = temp_dir + "pytorch_fuserXXXXXX.cpp";
// Command used by programExists(); an exit status of 0 means "found".
static const std::string check_exists_string = "where ${program} > nul 2> nul";
// Environment captured by activate() from the Visual Studio developer shell.
// When non-empty, run() launches commands with exactly this environment.
static std::vector<std::wstring> env_list;
constexpr int so_suffix_len = 4;
constexpr int cpp_suffix_len = 4;
#else
// Templates handed to TempFile together with the length of the extension
// that follows the "XXXXXX" placeholder (".so" / ".cpp").
static const std::string so_template = "/tmp/pytorch_fuserXXXXXX.so";
static const std::string cpp_template = "/tmp/pytorch_fuserXXXXXX.cpp";
// Command used by programExists(); an exit status of 0 means "found".
static const std::string check_exists_string = "which ${program} > /dev/null";
constexpr int so_suffix_len = 3;
constexpr int cpp_suffix_len = 4;
#endif
45 
46 intptr_t run(const std::string& cmd);
47 
programExists(const std::string & program)48 static bool programExists(const std::string& program) {
49   std::stringstream ss;
50   c10::printQuotedString(ss, program);
51   at::jit::TemplateEnv env;
52   env.s("program", ss.str());
53   std::string cmd = format(check_exists_string, env);
54 #ifdef _MSC_VER
55   return (run(cmd.c_str()) == 0);
56 #else
57   return (system(cmd.c_str()) == 0);
58 #endif
59 }
60 
61 #ifdef _MSC_VER
exec(const std::wstring & cmd)62 std::optional<std::wstring> exec(const std::wstring& cmd) {
63   std::array<wchar_t, 128> buffer;
64   std::wstring result;
65   std::unique_ptr<FILE, decltype(&_pclose)> pipe(
66       _wpopen(cmd.c_str(), L"r"), _pclose);
67   if (!pipe) {
68     return std::nullopt;
69   }
70   while (fgetws(buffer.data(), static_cast<int>(buffer.size()), pipe.get()) !=
71          nullptr) {
72     result += buffer.data();
73   }
74   return result;
75 }
76 
// Removes, in place, every trailing character of `s` that appears in `t`
// (ASCII whitespace by default) and returns a reference to `s`.
inline std::wstring& rtrim(std::wstring& s, const wchar_t* t = L" \t\n\r\f\v") {
  const std::wstring::size_type last_kept = s.find_last_not_of(t);
  // When every character is trimmable, nothing is kept and `s` becomes empty.
  const std::wstring::size_type new_size =
      (last_kept == std::wstring::npos) ? 0 : last_kept + 1;
  s.erase(new_size);
  return s;
}
81 
activate()82 void activate() {
83   wchar_t* root = nullptr;
84   std::wstring cmd;
85   std::optional<std::wstring> exec_out;
86   std::wstring path;
87   std::wstring vcruntime_plat;
88   std::wstring envvars;
89 
90   // Checking whether the environment is already activated
91   if (_wgetenv(L"VSCMD_ARG_TGT_ARCH")) {
92     return;
93   }
94 
95   // Getting `ProgramFiles` through environment variable queries
96   root = _wgetenv(L"ProgramFiles(x86)");
97   if (!root) {
98     root = _wgetenv(L"ProgramFiles");
99   }
100   if (!root) {
101     return;
102   }
103 
104   // Getting VS 2017 installation path through `vswhere`
105   cmd = L"\"" + std::wstring(root) +
106       L"\\Microsoft Visual Studio\\Installer\\vswhere.exe\""
107       L" -latest -prerelease -requires Microsoft.VisualStudio.Component.VC.Tools.x86.x64 -property installationPath";
108   exec_out = exec(cmd);
109   if (!exec_out) {
110     return;
111   }
112   path = *exec_out;
113   rtrim(path);
114 
115   // Checking whether the activation script `vcvarsall.bat` exists
116   path += L"\\VC\\Auxiliary\\Build";
117   struct _stati64 st;
118   if (_wstati64(path.c_str(), &st) == -1 || !(st.st_mode & _S_IFDIR)) {
119     return;
120   }
121   path += L"\\vcvarsall.bat";
122   if (_waccess(path.c_str(), 0) == -1) {
123     return;
124   }
125 
126   // Determining current platform
127   if (sizeof(void*) == 8) {
128     vcruntime_plat = L"x64";
129   } else {
130     vcruntime_plat = L"x86";
131   }
132 
133   // Getting environment variables after activating VS development shell
134   cmd = L"\"" + path + L"\" " + vcruntime_plat + L">NUL && set";
135   exec_out = exec(cmd);
136   if (!exec_out) {
137     return;
138   }
139   envvars = *exec_out;
140 
141   // Setting environment variables to the current environment
142   std::wistringstream f(envvars);
143   std::wstring envvar;
144   while (getline(f, envvar, L'\n')) {
145     env_list.push_back(envvar);
146   }
147 }
148 
run(const std::string & cmd)149 intptr_t run(const std::string& cmd) {
150   // Getting the path of `cmd.exe`
151   wchar_t* comspec = _wgetenv(L"COMSPEC");
152   if (!comspec) {
153     comspec = L"C:\\Windows\\System32\\cmd.exe";
154   }
155   // Constructing the command line
156   auto wCmd = c10::u8u16(cmd);
157   const wchar_t* a[] = {L"/c", wCmd.c_str(), nullptr};
158   // Constructing the env array
159   // If `env_list` is not empty, then add char pointers ending with nullptr.
160   // Otherwise, it will be nullptr, which implies the default env.
161   std::vector<const wchar_t*> e;
162   if (!env_list.empty()) {
163     for (auto& s : env_list) {
164       e.push_back(s.c_str());
165     }
166     e.push_back(nullptr);
167   }
168   // Running the command
169   intptr_t r = _wspawnve(_P_WAIT, comspec, a, e.data());
170   return r;
171 }
172 #endif
173 
174 // A single compiler config is accessed through getConfig() (below)
175 // Controls compilation options and may be updated based on the result
176 // of compilation attempts.
177 struct CompilerConfig {
CompilerConfigtorch::jit::fuser::cpu::CompilerConfig178   CompilerConfig() {
179     const char* cxx_env = getenv("CXX");
180     if (cxx_env != nullptr) {
181       cxx = cxx_env;
182     }
183 
184 #ifdef _MSC_VER
185     activate();
186 #endif
187 
188     if (!programExists(cxx)) {
189       TORCH_WARN("Compiler passed via CXX envvar does not exist!");
190       cxx = "";
191     }
192   }
193 
194   ~CompilerConfig() = default;
195 
196 #ifdef _MSC_VER
197   std::string cxx = "cl";
198   const std::string openmp_flags = "/openmp";
199 #elif defined(__clang__)
200   std::string cxx = "clang++";
201   const std::string openmp_flags = "-fopenmp";
202 #else
203   std::string cxx = "g++";
204   const std::string openmp_flags = "-fopenmp";
205 #endif
206 // Set openmp to true only if PyTorch is compiled with OpenMP support
207 // OpenMP is typically not available on MacOS platform
208 #if defined(_OPENMP)
209   bool openmp = true;
210 #else
211   bool openmp = false;
212 #endif
213 };
214 
getConfig()215 static CompilerConfig& getConfig() {
216   static CompilerConfig config;
217   return config;
218 }
219 
// NB: -march=native not supported on PPC64 g++.  It's a bit annoying
// to do a configure-style test to decide whether or not the g++
// actually supports it or not, so we heuristically use the host
// compiler to predict if the runtime compiler supports the option we
// want.  This probably won't work if you're cross-compiling.
// NB: -march=native is disabled because it has caused problems where
// the compiler and assembler do not agree on which native instructions they
// understand for AVX512. When we need better CPU performance this
// optimization can be re-enabled by tracking down the platforms where
// this error occurs and only selectively disabling it.
#ifdef _MSC_VER
#if !defined(_M_ARM64)
// According to https://stackoverflow.com/a/29178079, we are able to
// detect which arch level is supported by the vectorizer using
// the variable __isa_available, whose value is determined at runtime.
// The result of __isa_available and the corresponding arch:
//  AVX       4
//  AVX2      5
//  AVX512    6
extern "C" int __isa_available;
// Returns the highest /arch:AVX* flag the running CPU supports, or "".
static std::string getArchFlags() {
  if (__isa_available >= 6) {
    return "/arch:AVX512";
  } else if (__isa_available >= 5) {
    return "/arch:AVX2";
  } else if (__isa_available >= 4) {
    return "/arch:AVX";
  } else {
    return "";
  }
}
#else
// The /arch:AVX* flags are x86-only; on ARM64 let cl use its default target.
static std::string getArchFlags() {
  return "";
}
#endif
static const std::string arch_flags = getArchFlags();
// All MSVC builds (ARM64 included) need cl-style options; the GCC/Clang
// command line below is not understood by cl.
static const std::string compile_string = "cd /D \"" + temp_dir +
    "\" && "
    "${cxx} /nologo /MD /O2 " +
    arch_flags +
    " /LD /EHsc "
    "${fopenmp} \"${cpp_file}\" /link /out:\"${so_file}\"";
#else
static const std::string compile_string =
    "\"${cxx}\" -O3 -g "
#ifndef __PPC64__
//  "-march=native "
#endif
    "-std=c++17 -fPIC ${fopenmp} -shared \"${cpp_file}\" -o \"${so_file}\" -lm";
#endif
runCompiler(const std::string & cpp_file,const std::string & so_file)265 static void runCompiler(
266     const std::string& cpp_file,
267     const std::string& so_file) {
268   auto& config = getConfig();
269   TORCH_CHECK(
270       !config.cxx.empty(),
271       "Failed to compile a fused CPU kernel: Compiler not found");
272   at::jit::TemplateEnv env;
273   env.s("cxx", config.cxx);
274   env.s("fopenmp", config.openmp ? config.openmp_flags : "");
275   env.s("cpp_file", cpp_file);
276   env.s("so_file", so_file);
277   std::string result = format(compile_string, env);
278 #ifdef _MSC_VER
279   intptr_t r = run(result);
280 #else
281   int r = system(result.c_str());
282 #endif
283   if (config.openmp && r != 0) {
284     std::cerr
285         << "warning: pytorch jit fuser failed to compile with openmp, trying without it...\n";
286     config.openmp = false; // disable for future compiles
287     return runCompiler(cpp_file, so_file);
288   }
289   TORCH_CHECK(r == 0, "Failed to compile a fused CPU kernel");
290 }
291 
292 #ifdef _MSC_VER
293 static const std::string disas_string =
294     "dumpbin /DISASM:NOBYTES \"${so_file}\"";
295 #else
296 static const std::string disas_string = "objdump -M  intel -d \"${so_file}\"";
297 #endif
disas(const std::string & so_file)298 static void disas(const std::string& so_file) {
299   at::jit::TemplateEnv env;
300   env.s("so_file", so_file);
301   std::string cmd = format(disas_string, env);
302   int r = system(cmd.c_str());
303   AT_ASSERT(r == 0);
304 }
305 
FusedKernelCPU(std::string name,std::string code,std::vector<TensorDesc> input_desc,std::vector<TensorDesc> output_desc,std::vector<PartitionDesc> chunk_desc,std::vector<PartitionDesc> concat_desc,bool has_random)306 FusedKernelCPU::FusedKernelCPU(
307     std::string name,
308     std::string code,
309     std::vector<TensorDesc> input_desc,
310     std::vector<TensorDesc> output_desc,
311     std::vector<PartitionDesc> chunk_desc,
312     std::vector<PartitionDesc> concat_desc,
313     bool has_random)
314     : FusedKernel(
315           std::move(name),
316           std::move(code),
317           std::move(input_desc),
318           std::move(output_desc),
319           std::move(chunk_desc),
320           std::move(concat_desc),
321           has_random) {
322   TempFile so_file(so_template, so_suffix_len);
323   TempFile cpp_file(cpp_template, cpp_suffix_len);
324   cpp_file.write(code_);
325   cpp_file.sync();
326 #ifdef _MSC_VER
327   so_file.close();
328   cpp_file.close();
329 #endif
330   runCompiler(cpp_file.name(), so_file.name());
331   if (debugFuser() >= 2)
332     disas(so_file.name());
333   so_lib = std::make_unique<at::DynamicLibrary>(so_file.name().c_str());
334 #pragma GCC diagnostic ignored "-Wpedantic"
335   kernel =
336       reinterpret_cast<void (*)(uint32_t, void**)>(so_lib->sym(name_.c_str()));
337 #pragma GCC diagnostic pop
338 }
339 
createFusionKernel(int16_t device,std::string name,std::string code,std::vector<TensorDesc> input_desc,std::vector<TensorDesc> output_desc,std::vector<PartitionDesc> chunk_desc,std::vector<PartitionDesc> concat_desc,bool has_random)340 static std::shared_ptr<FusedKernel> createFusionKernel(
341     int16_t device,
342     std::string name,
343     std::string code,
344     std::vector<TensorDesc> input_desc,
345     std::vector<TensorDesc> output_desc,
346     std::vector<PartitionDesc> chunk_desc,
347     std::vector<PartitionDesc> concat_desc,
348     bool has_random) {
349   return std::make_shared<FusedKernelCPU>(
350       std::move(name),
351       std::move(code),
352       std::move(input_desc),
353       std::move(output_desc),
354       std::move(chunk_desc),
355       std::move(concat_desc),
356       has_random);
357 }
358 
359 RegisterFusionBackend reg(DeviceType::CPU, createFusionKernel);
360 } // namespace cpu
361 } // namespace fuser
362 } // namespace jit
363 } // namespace torch
364