xref: /aosp_15_r20/external/pytorch/aten/src/ATen/Context.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <ATen/Config.h>
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <ATen/Context.h>
4*da0073e9SAndroid Build Coastguard Worker 
5*da0073e9SAndroid Build Coastguard Worker #include <c10/core/CPUAllocator.h>
6*da0073e9SAndroid Build Coastguard Worker 
7*da0073e9SAndroid Build Coastguard Worker #include <algorithm>
8*da0073e9SAndroid Build Coastguard Worker #include <cctype>
9*da0073e9SAndroid Build Coastguard Worker #include <string>
10*da0073e9SAndroid Build Coastguard Worker #include <stdexcept>
11*da0073e9SAndroid Build Coastguard Worker 
12*da0073e9SAndroid Build Coastguard Worker #include <ATen/cpu/FlushDenormal.h>
13*da0073e9SAndroid Build Coastguard Worker 
14*da0073e9SAndroid Build Coastguard Worker #ifdef USE_FBGEMM
15*da0073e9SAndroid Build Coastguard Worker #include <fbgemm/Fbgemm.h>
16*da0073e9SAndroid Build Coastguard Worker #endif // USE_FBGEMM
17*da0073e9SAndroid Build Coastguard Worker #if defined(__aarch64__) && !defined(C10_MOBILE)
18*da0073e9SAndroid Build Coastguard Worker #include <cpuinfo.h>
19*da0073e9SAndroid Build Coastguard Worker #endif
20*da0073e9SAndroid Build Coastguard Worker 
21*da0073e9SAndroid Build Coastguard Worker namespace at {
22*da0073e9SAndroid Build Coastguard Worker 
23*da0073e9SAndroid Build Coastguard Worker Context::Context() = default;
24*da0073e9SAndroid Build Coastguard Worker 
25*da0073e9SAndroid Build Coastguard Worker // TODO: This could be bad juju if someone calls globalContext() in the
26*da0073e9SAndroid Build Coastguard Worker // destructor of an object with static lifetime.
globalContext()27*da0073e9SAndroid Build Coastguard Worker Context& globalContext() {
28*da0073e9SAndroid Build Coastguard Worker   static Context globalContext_;
29*da0073e9SAndroid Build Coastguard Worker   return globalContext_;
30*da0073e9SAndroid Build Coastguard Worker }
31*da0073e9SAndroid Build Coastguard Worker 
32*da0073e9SAndroid Build Coastguard Worker // NB: This method is *purely* whether or not a user requested
33*da0073e9SAndroid Build Coastguard Worker // that CuDNN was enabled, it doesn't actually say anything about
34*da0073e9SAndroid Build Coastguard Worker // whether or not CuDNN is actually usable.
userEnabledCuDNN() const35*da0073e9SAndroid Build Coastguard Worker bool Context::userEnabledCuDNN() const {
36*da0073e9SAndroid Build Coastguard Worker   return enabled_cudnn;
37*da0073e9SAndroid Build Coastguard Worker }
38*da0073e9SAndroid Build Coastguard Worker 
setUserEnabledCuDNN(bool e)39*da0073e9SAndroid Build Coastguard Worker void Context::setUserEnabledCuDNN(bool e) {
40*da0073e9SAndroid Build Coastguard Worker   enabled_cudnn = e;
41*da0073e9SAndroid Build Coastguard Worker }
42*da0073e9SAndroid Build Coastguard Worker 
userEnabledMkldnn() const43*da0073e9SAndroid Build Coastguard Worker bool Context::userEnabledMkldnn() const {
44*da0073e9SAndroid Build Coastguard Worker   return enabled_mkldnn;
45*da0073e9SAndroid Build Coastguard Worker }
46*da0073e9SAndroid Build Coastguard Worker 
setUserEnabledMkldnn(bool e)47*da0073e9SAndroid Build Coastguard Worker void Context::setUserEnabledMkldnn(bool e) {
48*da0073e9SAndroid Build Coastguard Worker   enabled_mkldnn = e;
49*da0073e9SAndroid Build Coastguard Worker }
50*da0073e9SAndroid Build Coastguard Worker 
deterministicCuDNN() const51*da0073e9SAndroid Build Coastguard Worker bool Context::deterministicCuDNN() const {
52*da0073e9SAndroid Build Coastguard Worker   return deterministic_cudnn;
53*da0073e9SAndroid Build Coastguard Worker }
54*da0073e9SAndroid Build Coastguard Worker 
setDeterministicCuDNN(bool b)55*da0073e9SAndroid Build Coastguard Worker void Context::setDeterministicCuDNN(bool b) {
56*da0073e9SAndroid Build Coastguard Worker   deterministic_cudnn = b;
57*da0073e9SAndroid Build Coastguard Worker }
58*da0073e9SAndroid Build Coastguard Worker 
deterministicAlgorithms() const59*da0073e9SAndroid Build Coastguard Worker bool Context::deterministicAlgorithms() const {
60*da0073e9SAndroid Build Coastguard Worker   return _deterministic_algorithms;
61*da0073e9SAndroid Build Coastguard Worker }
62*da0073e9SAndroid Build Coastguard Worker 
deterministicAlgorithmsWarnOnly() const63*da0073e9SAndroid Build Coastguard Worker bool Context::deterministicAlgorithmsWarnOnly() const {
64*da0073e9SAndroid Build Coastguard Worker   return _deterministic_algorithms_warn_only;
65*da0073e9SAndroid Build Coastguard Worker }
66*da0073e9SAndroid Build Coastguard Worker 
setDeterministicAlgorithms(bool b,bool warn_only=false)67*da0073e9SAndroid Build Coastguard Worker void Context::setDeterministicAlgorithms(bool b, bool warn_only=false) {
68*da0073e9SAndroid Build Coastguard Worker   _deterministic_algorithms = b;
69*da0073e9SAndroid Build Coastguard Worker   _deterministic_algorithms_warn_only = warn_only;
70*da0073e9SAndroid Build Coastguard Worker }
71*da0073e9SAndroid Build Coastguard Worker 
deterministicFillUninitializedMemory() const72*da0073e9SAndroid Build Coastguard Worker bool Context::deterministicFillUninitializedMemory() const {
73*da0073e9SAndroid Build Coastguard Worker   return _deterministic_fill_uninitialized_memory;
74*da0073e9SAndroid Build Coastguard Worker }
75*da0073e9SAndroid Build Coastguard Worker 
setDeterministicFillUninitializedMemory(bool b)76*da0073e9SAndroid Build Coastguard Worker void Context::setDeterministicFillUninitializedMemory(bool b) {
77*da0073e9SAndroid Build Coastguard Worker   _deterministic_fill_uninitialized_memory = b;
78*da0073e9SAndroid Build Coastguard Worker }
79*da0073e9SAndroid Build Coastguard Worker 
alertNotDeterministic(c10::string_view const & caller)80*da0073e9SAndroid Build Coastguard Worker void Context::alertNotDeterministic(c10::string_view const& caller) {
81*da0073e9SAndroid Build Coastguard Worker   if (globalContext().deterministicAlgorithms()) {
82*da0073e9SAndroid Build Coastguard Worker     if (globalContext().deterministicAlgorithmsWarnOnly()) {
83*da0073e9SAndroid Build Coastguard Worker       TORCH_WARN(
84*da0073e9SAndroid Build Coastguard Worker         caller, " does not have a deterministic implementation, but you set "
85*da0073e9SAndroid Build Coastguard Worker         "'torch.use_deterministic_algorithms(True, warn_only=True)'. "
86*da0073e9SAndroid Build Coastguard Worker         "You can file an issue at https://github.com/pytorch/pytorch/issues "
87*da0073e9SAndroid Build Coastguard Worker         "to help us prioritize adding deterministic support for this operation.");
88*da0073e9SAndroid Build Coastguard Worker     } else {
89*da0073e9SAndroid Build Coastguard Worker       TORCH_CHECK(false,
90*da0073e9SAndroid Build Coastguard Worker         caller, " does not have a deterministic implementation, but you set "
91*da0073e9SAndroid Build Coastguard Worker         "'torch.use_deterministic_algorithms(True)'. You can turn off "
92*da0073e9SAndroid Build Coastguard Worker         "determinism just for this operation, or you can use the "
93*da0073e9SAndroid Build Coastguard Worker         "'warn_only=True' option, if that's acceptable for your application. "
94*da0073e9SAndroid Build Coastguard Worker         "You can also file an issue at https://github.com/pytorch/pytorch/issues "
95*da0073e9SAndroid Build Coastguard Worker         "to help us prioritize adding deterministic support for this operation.");
96*da0073e9SAndroid Build Coastguard Worker     }
97*da0073e9SAndroid Build Coastguard Worker   }
98*da0073e9SAndroid Build Coastguard Worker }
99*da0073e9SAndroid Build Coastguard Worker 
userEnabledNNPACK() const100*da0073e9SAndroid Build Coastguard Worker bool Context::userEnabledNNPACK() const {
101*da0073e9SAndroid Build Coastguard Worker   return enabled_nnpack;
102*da0073e9SAndroid Build Coastguard Worker }
103*da0073e9SAndroid Build Coastguard Worker 
setUserEnabledNNPACK(bool e)104*da0073e9SAndroid Build Coastguard Worker void Context::setUserEnabledNNPACK(bool e) {
105*da0073e9SAndroid Build Coastguard Worker   enabled_nnpack = e;
106*da0073e9SAndroid Build Coastguard Worker }
107*da0073e9SAndroid Build Coastguard Worker 
allowTF32CuDNN() const108*da0073e9SAndroid Build Coastguard Worker bool Context::allowTF32CuDNN() const {
109*da0073e9SAndroid Build Coastguard Worker   return allow_tf32_cudnn;
110*da0073e9SAndroid Build Coastguard Worker }
111*da0073e9SAndroid Build Coastguard Worker 
setAllowTF32CuDNN(bool b)112*da0073e9SAndroid Build Coastguard Worker void Context::setAllowTF32CuDNN(bool b) {
113*da0073e9SAndroid Build Coastguard Worker   allow_tf32_cudnn = b;
114*da0073e9SAndroid Build Coastguard Worker }
115*da0073e9SAndroid Build Coastguard Worker 
userEnabledFlashSDP() const116*da0073e9SAndroid Build Coastguard Worker bool Context::userEnabledFlashSDP() const {
117*da0073e9SAndroid Build Coastguard Worker   return enabled_flashSDP;
118*da0073e9SAndroid Build Coastguard Worker }
119*da0073e9SAndroid Build Coastguard Worker 
setSDPUseFlash(bool e)120*da0073e9SAndroid Build Coastguard Worker void Context::setSDPUseFlash(bool e) {
121*da0073e9SAndroid Build Coastguard Worker   enabled_flashSDP = e;
122*da0073e9SAndroid Build Coastguard Worker }
123*da0073e9SAndroid Build Coastguard Worker 
userEnabledMemEfficientSDP() const124*da0073e9SAndroid Build Coastguard Worker bool Context::userEnabledMemEfficientSDP() const {
125*da0073e9SAndroid Build Coastguard Worker   return enabled_mem_efficientSDP;
126*da0073e9SAndroid Build Coastguard Worker }
127*da0073e9SAndroid Build Coastguard Worker 
setSDPUseMemEfficient(bool e)128*da0073e9SAndroid Build Coastguard Worker void Context::setSDPUseMemEfficient(bool e) {
129*da0073e9SAndroid Build Coastguard Worker   enabled_mem_efficientSDP = e;
130*da0073e9SAndroid Build Coastguard Worker }
131*da0073e9SAndroid Build Coastguard Worker 
userEnabledMathSDP() const132*da0073e9SAndroid Build Coastguard Worker bool Context::userEnabledMathSDP() const {
133*da0073e9SAndroid Build Coastguard Worker   return enabled_mathSDP;
134*da0073e9SAndroid Build Coastguard Worker }
135*da0073e9SAndroid Build Coastguard Worker 
setSDPUseMath(bool e)136*da0073e9SAndroid Build Coastguard Worker void Context::setSDPUseMath(bool e) {
137*da0073e9SAndroid Build Coastguard Worker   enabled_mathSDP = e;
138*da0073e9SAndroid Build Coastguard Worker }
139*da0073e9SAndroid Build Coastguard Worker 
userEnabledCuDNNSDP() const140*da0073e9SAndroid Build Coastguard Worker bool Context::userEnabledCuDNNSDP() const {
141*da0073e9SAndroid Build Coastguard Worker   return enabled_cudnnSDP;
142*da0073e9SAndroid Build Coastguard Worker }
143*da0073e9SAndroid Build Coastguard Worker 
setSDPUseCuDNN(bool e)144*da0073e9SAndroid Build Coastguard Worker void Context::setSDPUseCuDNN(bool e) {
145*da0073e9SAndroid Build Coastguard Worker   enabled_cudnnSDP = e;
146*da0073e9SAndroid Build Coastguard Worker }
147*da0073e9SAndroid Build Coastguard Worker 
148*da0073e9SAndroid Build Coastguard Worker 
// Environment variable holding the cuBLAS workspace configuration.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static constexpr char cublas_config_var_name[] = "CUBLAS_WORKSPACE_CONFIG";
// Workspace configuration values this file accepts as deterministic.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static constexpr const char* cublas_deterministic_configs[] = {":4096:8", ":16:8"};
153*da0073e9SAndroid Build Coastguard Worker 
checkCuBLASConfigDeterministic()154*da0073e9SAndroid Build Coastguard Worker bool Context::checkCuBLASConfigDeterministic() {
155*da0073e9SAndroid Build Coastguard Worker   bool cublas_config_deterministic = true;
156*da0073e9SAndroid Build Coastguard Worker   // If using CUDA 10.2 or greater, need to make sure CuBLAS workspace config
157*da0073e9SAndroid Build Coastguard Worker   // is set to deterministic setting
158*da0073e9SAndroid Build Coastguard Worker   if (hasCUDART() && (versionCUDART() >= 10020)) {
159*da0073e9SAndroid Build Coastguard Worker     char* workspace_config = std::getenv(cublas_config_var_name);
160*da0073e9SAndroid Build Coastguard Worker     cublas_config_deterministic = (workspace_config != nullptr) && (
161*da0073e9SAndroid Build Coastguard Worker       (strcmp(workspace_config, cublas_deterministic_configs[0]) == 0)
162*da0073e9SAndroid Build Coastguard Worker       || (strcmp(workspace_config, cublas_deterministic_configs[1]) == 0)
163*da0073e9SAndroid Build Coastguard Worker     );
164*da0073e9SAndroid Build Coastguard Worker   }
165*da0073e9SAndroid Build Coastguard Worker   return cublas_config_deterministic;
166*da0073e9SAndroid Build Coastguard Worker }
167*da0073e9SAndroid Build Coastguard Worker 
alertCuBLASConfigNotDeterministic() const168*da0073e9SAndroid Build Coastguard Worker void Context::alertCuBLASConfigNotDeterministic() const {
169*da0073e9SAndroid Build Coastguard Worker   static bool cublas_config_deterministic = checkCuBLASConfigDeterministic();
170*da0073e9SAndroid Build Coastguard Worker   if (C10_LIKELY(!deterministicAlgorithms() || cublas_config_deterministic)) {
171*da0073e9SAndroid Build Coastguard Worker     return;
172*da0073e9SAndroid Build Coastguard Worker   }
173*da0073e9SAndroid Build Coastguard Worker 
174*da0073e9SAndroid Build Coastguard Worker   auto msg = c10::str(
175*da0073e9SAndroid Build Coastguard Worker     "Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or ",
176*da0073e9SAndroid Build Coastguard Worker     "`at::Context::setDeterministicAlgorithms(true)`, but this operation is not deterministic because ",
177*da0073e9SAndroid Build Coastguard Worker     "it uses CuBLAS and you have CUDA >= 10.2. To enable deterministic behavior in this ",
178*da0073e9SAndroid Build Coastguard Worker     "case, you must set an environment variable before running your PyTorch application: ",
179*da0073e9SAndroid Build Coastguard Worker     cublas_config_var_name, "=", cublas_deterministic_configs[0], " or ",
180*da0073e9SAndroid Build Coastguard Worker     cublas_config_var_name, "=", cublas_deterministic_configs[1], ". For more information, go to ",
181*da0073e9SAndroid Build Coastguard Worker     "https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility"
182*da0073e9SAndroid Build Coastguard Worker   );
183*da0073e9SAndroid Build Coastguard Worker 
184*da0073e9SAndroid Build Coastguard Worker   if (deterministicAlgorithmsWarnOnly()) {
185*da0073e9SAndroid Build Coastguard Worker     TORCH_WARN(msg);
186*da0073e9SAndroid Build Coastguard Worker   } else {
187*da0073e9SAndroid Build Coastguard Worker     TORCH_CHECK(false, msg);
188*da0073e9SAndroid Build Coastguard Worker   }
189*da0073e9SAndroid Build Coastguard Worker }
190*da0073e9SAndroid Build Coastguard Worker 
benchmarkCuDNN() const191*da0073e9SAndroid Build Coastguard Worker bool Context::benchmarkCuDNN() const {
192*da0073e9SAndroid Build Coastguard Worker   return benchmark_cudnn;
193*da0073e9SAndroid Build Coastguard Worker }
194*da0073e9SAndroid Build Coastguard Worker 
setBenchmarkCuDNN(bool b)195*da0073e9SAndroid Build Coastguard Worker void Context::setBenchmarkCuDNN(bool b) {
196*da0073e9SAndroid Build Coastguard Worker   benchmark_cudnn = b;
197*da0073e9SAndroid Build Coastguard Worker }
198*da0073e9SAndroid Build Coastguard Worker 
benchmarkLimitCuDNN() const199*da0073e9SAndroid Build Coastguard Worker int Context::benchmarkLimitCuDNN() const {
200*da0073e9SAndroid Build Coastguard Worker   return benchmark_limit_cudnn;
201*da0073e9SAndroid Build Coastguard Worker }
202*da0073e9SAndroid Build Coastguard Worker 
setBenchmarkLimitCuDNN(int b)203*da0073e9SAndroid Build Coastguard Worker void Context::setBenchmarkLimitCuDNN(int b) {
204*da0073e9SAndroid Build Coastguard Worker   benchmark_limit_cudnn = b;
205*da0073e9SAndroid Build Coastguard Worker }
206*da0073e9SAndroid Build Coastguard Worker 
allowTF32CuBLAS() const207*da0073e9SAndroid Build Coastguard Worker bool Context::allowTF32CuBLAS() const {
208*da0073e9SAndroid Build Coastguard Worker   return float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST;
209*da0073e9SAndroid Build Coastguard Worker }
210*da0073e9SAndroid Build Coastguard Worker 
setAllowTF32CuBLAS(bool b)211*da0073e9SAndroid Build Coastguard Worker void Context::setAllowTF32CuBLAS(bool b) {
212*da0073e9SAndroid Build Coastguard Worker   float32_matmul_precision = b ? at::Float32MatmulPrecision::HIGH : at::Float32MatmulPrecision::HIGHEST;
213*da0073e9SAndroid Build Coastguard Worker }
214*da0073e9SAndroid Build Coastguard Worker 
float32MatmulPrecision() const215*da0073e9SAndroid Build Coastguard Worker Float32MatmulPrecision Context::float32MatmulPrecision() const {
216*da0073e9SAndroid Build Coastguard Worker   return float32_matmul_precision;
217*da0073e9SAndroid Build Coastguard Worker }
218*da0073e9SAndroid Build Coastguard Worker 
setFloat32MatmulPrecision(Float32MatmulPrecision p)219*da0073e9SAndroid Build Coastguard Worker void Context::setFloat32MatmulPrecision(Float32MatmulPrecision p) {
220*da0073e9SAndroid Build Coastguard Worker   float32_matmul_precision = p;
221*da0073e9SAndroid Build Coastguard Worker }
222*da0073e9SAndroid Build Coastguard Worker 
setFloat32MatmulPrecision(const std::string & s)223*da0073e9SAndroid Build Coastguard Worker void Context::setFloat32MatmulPrecision(const std::string &s) {
224*da0073e9SAndroid Build Coastguard Worker   auto match = [this](const std::string & s_) {
225*da0073e9SAndroid Build Coastguard Worker     // TODO: consider if CuDNN field needs to also be set for potential future CuDNN ops like multi-headed attention
226*da0073e9SAndroid Build Coastguard Worker     if (s_ == "highest") {
227*da0073e9SAndroid Build Coastguard Worker       float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST;
228*da0073e9SAndroid Build Coastguard Worker       return true;
229*da0073e9SAndroid Build Coastguard Worker     } else if (s_ == "high") {
230*da0073e9SAndroid Build Coastguard Worker       float32_matmul_precision = at::Float32MatmulPrecision::HIGH;
231*da0073e9SAndroid Build Coastguard Worker       return true;
232*da0073e9SAndroid Build Coastguard Worker     } else if (s_ == "medium") {
233*da0073e9SAndroid Build Coastguard Worker       float32_matmul_precision = at::Float32MatmulPrecision::MEDIUM;
234*da0073e9SAndroid Build Coastguard Worker       return true;
235*da0073e9SAndroid Build Coastguard Worker     }
236*da0073e9SAndroid Build Coastguard Worker     return false;
237*da0073e9SAndroid Build Coastguard Worker   };
238*da0073e9SAndroid Build Coastguard Worker   if (match(s)) { return; }
239*da0073e9SAndroid Build Coastguard Worker   std::string sl;
240*da0073e9SAndroid Build Coastguard Worker   std::transform(s.begin(), s.end(), sl.begin(),
241*da0073e9SAndroid Build Coastguard Worker                  [](unsigned char c) -> unsigned char { return std::tolower(c); });
242*da0073e9SAndroid Build Coastguard Worker   if (match(sl)) { return; }
243*da0073e9SAndroid Build Coastguard Worker   TORCH_WARN(s, " is not one of 'highest', 'high', or 'medium'; the current"
244*da0073e9SAndroid Build Coastguard Worker     "setFloat32MatmulPrecision call has no effect.");
245*da0073e9SAndroid Build Coastguard Worker }
246*da0073e9SAndroid Build Coastguard Worker 
linalgPreferredBackend() const247*da0073e9SAndroid Build Coastguard Worker at::LinalgBackend Context::linalgPreferredBackend() const {
248*da0073e9SAndroid Build Coastguard Worker   return linalg_preferred_backend;
249*da0073e9SAndroid Build Coastguard Worker }
250*da0073e9SAndroid Build Coastguard Worker 
setLinalgPreferredBackend(at::LinalgBackend b)251*da0073e9SAndroid Build Coastguard Worker void Context::setLinalgPreferredBackend(at::LinalgBackend b) {
252*da0073e9SAndroid Build Coastguard Worker   linalg_preferred_backend = b;
253*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK((b != at::LinalgBackend::Cusolver) || hasCuSOLVER(),
254*da0073e9SAndroid Build Coastguard Worker       "Cannot set preferred backend to cuSOLVER if PyTorch has not been compiled with cuSOLVER.");
255*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK((b != at::LinalgBackend::Magma) || hasMAGMA(),
256*da0073e9SAndroid Build Coastguard Worker       "Cannot set preferred backend to MAGMA if PyTorch has not been compiled with MAGMA.");
257*da0073e9SAndroid Build Coastguard Worker   if (b != at::LinalgBackend::Default) {
258*da0073e9SAndroid Build Coastguard Worker     TORCH_WARN_ONCE(
259*da0073e9SAndroid Build Coastguard Worker       "torch.backends.cuda.preferred_linalg_library is an experimental feature. "
260*da0073e9SAndroid Build Coastguard Worker       "If you see any error or unexpected behavior when this flag is set "
261*da0073e9SAndroid Build Coastguard Worker       "please file an issue on GitHub."
262*da0073e9SAndroid Build Coastguard Worker     );
263*da0073e9SAndroid Build Coastguard Worker   }
264*da0073e9SAndroid Build Coastguard Worker }
265*da0073e9SAndroid Build Coastguard Worker 
blasPreferredBackend()266*da0073e9SAndroid Build Coastguard Worker at::BlasBackend Context::blasPreferredBackend() {
267*da0073e9SAndroid Build Coastguard Worker #ifdef USE_ROCM
268*da0073e9SAndroid Build Coastguard Worker   if (blas_preferred_backend == at::BlasBackend::Cublaslt) {
269*da0073e9SAndroid Build Coastguard Worker     static const bool hipblaslt_unsupported = []() {
270*da0073e9SAndroid Build Coastguard Worker       static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
271*da0073e9SAndroid Build Coastguard Worker       for (auto index = 0; index < at::getNumGPUs(); index++) {
272*da0073e9SAndroid Build Coastguard Worker         if (!detail::getCUDAHooks().isGPUArch(index, archs)) {
273*da0073e9SAndroid Build Coastguard Worker           TORCH_WARN_ONCE(
274*da0073e9SAndroid Build Coastguard Worker             "Attempting to use hipBLASLt on an unsupported architecture! "
275*da0073e9SAndroid Build Coastguard Worker             "Overriding blas backend to hipblas");
276*da0073e9SAndroid Build Coastguard Worker           return true;
277*da0073e9SAndroid Build Coastguard Worker         }
278*da0073e9SAndroid Build Coastguard Worker       }
279*da0073e9SAndroid Build Coastguard Worker       return false;
280*da0073e9SAndroid Build Coastguard Worker     }();
281*da0073e9SAndroid Build Coastguard Worker     if (hipblaslt_unsupported) blas_preferred_backend = at::BlasBackend::Cublas;
282*da0073e9SAndroid Build Coastguard Worker   }
283*da0073e9SAndroid Build Coastguard Worker #endif
284*da0073e9SAndroid Build Coastguard Worker   return blas_preferred_backend;
285*da0073e9SAndroid Build Coastguard Worker }
286*da0073e9SAndroid Build Coastguard Worker 
setBlasPreferredBackend(at::BlasBackend b)287*da0073e9SAndroid Build Coastguard Worker void Context::setBlasPreferredBackend(at::BlasBackend b) {
288*da0073e9SAndroid Build Coastguard Worker #ifdef _MSC_VER
289*da0073e9SAndroid Build Coastguard Worker   TORCH_WARN_ONCE(
290*da0073e9SAndroid Build Coastguard Worker     "torch.backends.cuda.preferred_blas_library is an experimental feature. "
291*da0073e9SAndroid Build Coastguard Worker     "It is not supported on Windows."
292*da0073e9SAndroid Build Coastguard Worker   );
293*da0073e9SAndroid Build Coastguard Worker #else
294*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK((b != at::BlasBackend::Cublaslt) || hasCuBLASLt(),
295*da0073e9SAndroid Build Coastguard Worker       "Cannot set preferred backend to cuBLASLt if PyTorch has not been compiled with cuBLASLt.");
296*da0073e9SAndroid Build Coastguard Worker   if (b != at::BlasBackend::Cublas) {
297*da0073e9SAndroid Build Coastguard Worker     TORCH_WARN_ONCE(
298*da0073e9SAndroid Build Coastguard Worker       "torch.backends.cuda.preferred_blas_library is an experimental feature. "
299*da0073e9SAndroid Build Coastguard Worker       "If you see any error or unexpected behavior when this flag is set "
300*da0073e9SAndroid Build Coastguard Worker       "please file an issue on GitHub."
301*da0073e9SAndroid Build Coastguard Worker     );
302*da0073e9SAndroid Build Coastguard Worker   }
303*da0073e9SAndroid Build Coastguard Worker   blas_preferred_backend = b;
304*da0073e9SAndroid Build Coastguard Worker #endif
305*da0073e9SAndroid Build Coastguard Worker }
306*da0073e9SAndroid Build Coastguard Worker 
allowFP16ReductionCuBLAS() const307*da0073e9SAndroid Build Coastguard Worker bool Context::allowFP16ReductionCuBLAS() const {
308*da0073e9SAndroid Build Coastguard Worker   return allow_fp16_reduction_cublas;
309*da0073e9SAndroid Build Coastguard Worker }
310*da0073e9SAndroid Build Coastguard Worker 
setAllowFP16ReductionCuBLAS(bool b)311*da0073e9SAndroid Build Coastguard Worker void Context::setAllowFP16ReductionCuBLAS(bool b) {
312*da0073e9SAndroid Build Coastguard Worker   allow_fp16_reduction_cublas = b;
313*da0073e9SAndroid Build Coastguard Worker }
314*da0073e9SAndroid Build Coastguard Worker 
allowBF16ReductionCuBLAS() const315*da0073e9SAndroid Build Coastguard Worker bool Context::allowBF16ReductionCuBLAS() const {
316*da0073e9SAndroid Build Coastguard Worker   return allow_bf16_reduction_cublas;
317*da0073e9SAndroid Build Coastguard Worker }
318*da0073e9SAndroid Build Coastguard Worker 
setAllowBF16ReductionCuBLAS(bool b)319*da0073e9SAndroid Build Coastguard Worker void Context::setAllowBF16ReductionCuBLAS(bool b) {
320*da0073e9SAndroid Build Coastguard Worker   allow_bf16_reduction_cublas = b;
321*da0073e9SAndroid Build Coastguard Worker }
322*da0073e9SAndroid Build Coastguard Worker 
323*da0073e9SAndroid Build Coastguard Worker 
hasMKL()324*da0073e9SAndroid Build Coastguard Worker bool Context::hasMKL() {
325*da0073e9SAndroid Build Coastguard Worker #if AT_MKL_ENABLED()
326*da0073e9SAndroid Build Coastguard Worker   return true;
327*da0073e9SAndroid Build Coastguard Worker #else
328*da0073e9SAndroid Build Coastguard Worker   return false;
329*da0073e9SAndroid Build Coastguard Worker #endif
330*da0073e9SAndroid Build Coastguard Worker }
331*da0073e9SAndroid Build Coastguard Worker 
hasMKLDNN()332*da0073e9SAndroid Build Coastguard Worker bool Context::hasMKLDNN() {
333*da0073e9SAndroid Build Coastguard Worker #if AT_MKLDNN_ENABLED()
334*da0073e9SAndroid Build Coastguard Worker   return true;
335*da0073e9SAndroid Build Coastguard Worker #else
336*da0073e9SAndroid Build Coastguard Worker   return false;
337*da0073e9SAndroid Build Coastguard Worker #endif
338*da0073e9SAndroid Build Coastguard Worker }
339*da0073e9SAndroid Build Coastguard Worker 
hasOpenMP()340*da0073e9SAndroid Build Coastguard Worker bool Context::hasOpenMP() {
341*da0073e9SAndroid Build Coastguard Worker #ifdef _OPENMP
342*da0073e9SAndroid Build Coastguard Worker   return true;
343*da0073e9SAndroid Build Coastguard Worker #else
344*da0073e9SAndroid Build Coastguard Worker   return false;
345*da0073e9SAndroid Build Coastguard Worker #endif
346*da0073e9SAndroid Build Coastguard Worker }
347*da0073e9SAndroid Build Coastguard Worker 
hasLAPACK()348*da0073e9SAndroid Build Coastguard Worker bool Context::hasLAPACK() {
349*da0073e9SAndroid Build Coastguard Worker #if AT_BUILD_WITH_LAPACK()
350*da0073e9SAndroid Build Coastguard Worker   return true;
351*da0073e9SAndroid Build Coastguard Worker #else
352*da0073e9SAndroid Build Coastguard Worker   return false;
353*da0073e9SAndroid Build Coastguard Worker #endif
354*da0073e9SAndroid Build Coastguard Worker }
355*da0073e9SAndroid Build Coastguard Worker 
qEngine() const356*da0073e9SAndroid Build Coastguard Worker at::QEngine Context::qEngine() const {
357*da0073e9SAndroid Build Coastguard Worker   static auto _quantized_engine = []() {
358*da0073e9SAndroid Build Coastguard Worker     at::QEngine qengine = at::kNoQEngine;
359*da0073e9SAndroid Build Coastguard Worker #if defined(C10_MOBILE) && defined(USE_PYTORCH_QNNPACK)
360*da0073e9SAndroid Build Coastguard Worker     qengine = at::kQNNPACK;
361*da0073e9SAndroid Build Coastguard Worker #endif
362*da0073e9SAndroid Build Coastguard Worker 
363*da0073e9SAndroid Build Coastguard Worker #if AT_MKLDNN_ENABLED()
364*da0073e9SAndroid Build Coastguard Worker     qengine = at::kONEDNN;
365*da0073e9SAndroid Build Coastguard Worker #endif
366*da0073e9SAndroid Build Coastguard Worker 
367*da0073e9SAndroid Build Coastguard Worker #ifdef USE_FBGEMM
368*da0073e9SAndroid Build Coastguard Worker     if (fbgemm::fbgemmSupportedCPU()) {
369*da0073e9SAndroid Build Coastguard Worker       /* X86 is enabled if and only if fbgemm is available.
370*da0073e9SAndroid Build Coastguard Worker        * It combines goodness of fbgemm and onednn by dispatching.
371*da0073e9SAndroid Build Coastguard Worker        * If onednn not available, always dispatch to fbgemm.
372*da0073e9SAndroid Build Coastguard Worker        * Make it default qengine for X86 CPU platforms.
373*da0073e9SAndroid Build Coastguard Worker       */
374*da0073e9SAndroid Build Coastguard Worker       qengine = at::kX86;
375*da0073e9SAndroid Build Coastguard Worker     }
376*da0073e9SAndroid Build Coastguard Worker #endif
377*da0073e9SAndroid Build Coastguard Worker     return qengine;
378*da0073e9SAndroid Build Coastguard Worker   }();
379*da0073e9SAndroid Build Coastguard Worker   return quantized_engine.value_or(_quantized_engine);
380*da0073e9SAndroid Build Coastguard Worker }
381*da0073e9SAndroid Build Coastguard Worker 
setQEngine(at::QEngine e)382*da0073e9SAndroid Build Coastguard Worker void Context::setQEngine(at::QEngine e) {
383*da0073e9SAndroid Build Coastguard Worker   const auto& qengines = supportedQEngines();
384*da0073e9SAndroid Build Coastguard Worker   if (std::find(qengines.begin(), qengines.end(), e) != qengines.end()) {
385*da0073e9SAndroid Build Coastguard Worker     quantized_engine = e;
386*da0073e9SAndroid Build Coastguard Worker     return;
387*da0073e9SAndroid Build Coastguard Worker   }
388*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(false, "quantized engine ", toString(e), " is not supported");
389*da0073e9SAndroid Build Coastguard Worker }
390*da0073e9SAndroid Build Coastguard Worker 
supportedQEngines()391*da0073e9SAndroid Build Coastguard Worker const std::vector<at::QEngine>& Context::supportedQEngines() {
392*da0073e9SAndroid Build Coastguard Worker   static auto supported_qengines = []() {
393*da0073e9SAndroid Build Coastguard Worker     std::vector<at::QEngine> engines = {};
394*da0073e9SAndroid Build Coastguard Worker     // Engines are listed in priority order: later one wins
395*da0073e9SAndroid Build Coastguard Worker     // By default we prefer FBGEMM if we're running on server side
396*da0073e9SAndroid Build Coastguard Worker     // QNNPACK on server side has some issue, so we disable it by default.
397*da0073e9SAndroid Build Coastguard Worker #ifdef C10_MOBILE
398*da0073e9SAndroid Build Coastguard Worker     engines.push_back(at::kNoQEngine);
399*da0073e9SAndroid Build Coastguard Worker #ifdef USE_PYTORCH_QNNPACK
400*da0073e9SAndroid Build Coastguard Worker     engines.push_back(at::kQNNPACK);
401*da0073e9SAndroid Build Coastguard Worker #endif
402*da0073e9SAndroid Build Coastguard Worker #else  // C10_MOBILE
403*da0073e9SAndroid Build Coastguard Worker #ifdef USE_PYTORCH_QNNPACK
404*da0073e9SAndroid Build Coastguard Worker     engines.push_back(at::kQNNPACK);
405*da0073e9SAndroid Build Coastguard Worker #endif
406*da0073e9SAndroid Build Coastguard Worker     engines.push_back(at::kNoQEngine);
407*da0073e9SAndroid Build Coastguard Worker #endif // C10_MOBILE
408*da0073e9SAndroid Build Coastguard Worker 
409*da0073e9SAndroid Build Coastguard Worker #if AT_MKLDNN_ENABLED()
410*da0073e9SAndroid Build Coastguard Worker     engines.push_back(at::kONEDNN);
411*da0073e9SAndroid Build Coastguard Worker #endif
412*da0073e9SAndroid Build Coastguard Worker 
413*da0073e9SAndroid Build Coastguard Worker #ifdef USE_FBGEMM
414*da0073e9SAndroid Build Coastguard Worker     if (fbgemm::fbgemmSupportedCPU()) {
415*da0073e9SAndroid Build Coastguard Worker       engines.push_back(at::kX86);
416*da0073e9SAndroid Build Coastguard Worker       // The X86 qengine is available if and only if FBGEMM is available
417*da0073e9SAndroid Build Coastguard Worker       engines.push_back(at::kFBGEMM);
418*da0073e9SAndroid Build Coastguard Worker     }
419*da0073e9SAndroid Build Coastguard Worker #endif
420*da0073e9SAndroid Build Coastguard Worker 
421*da0073e9SAndroid Build Coastguard Worker     return engines;
422*da0073e9SAndroid Build Coastguard Worker   }();
423*da0073e9SAndroid Build Coastguard Worker   return supported_qengines;
424*da0073e9SAndroid Build Coastguard Worker }
425*da0073e9SAndroid Build Coastguard Worker 
isXNNPACKAvailable()426*da0073e9SAndroid Build Coastguard Worker bool Context::isXNNPACKAvailable() {
427*da0073e9SAndroid Build Coastguard Worker #ifdef USE_XNNPACK
428*da0073e9SAndroid Build Coastguard Worker   return true;
429*da0073e9SAndroid Build Coastguard Worker #else
430*da0073e9SAndroid Build Coastguard Worker   return false;
431*da0073e9SAndroid Build Coastguard Worker #endif
432*da0073e9SAndroid Build Coastguard Worker }
433*da0073e9SAndroid Build Coastguard Worker 
setCheckSparseTensorInvariants(bool e)434*da0073e9SAndroid Build Coastguard Worker void Context::setCheckSparseTensorInvariants(bool e) {
435*da0073e9SAndroid Build Coastguard Worker   enable_sparse_tensor_invariant_checks = e;
436*da0073e9SAndroid Build Coastguard Worker }
437*da0073e9SAndroid Build Coastguard Worker 
checkSparseTensorInvariants() const438*da0073e9SAndroid Build Coastguard Worker bool Context::checkSparseTensorInvariants() const {
439*da0073e9SAndroid Build Coastguard Worker   return enable_sparse_tensor_invariant_checks;
440*da0073e9SAndroid Build Coastguard Worker }
441*da0073e9SAndroid Build Coastguard Worker 
releaseWeightsWhenPrepacking() const442*da0073e9SAndroid Build Coastguard Worker bool Context::releaseWeightsWhenPrepacking() const {
443*da0073e9SAndroid Build Coastguard Worker   return release_original_weights;
444*da0073e9SAndroid Build Coastguard Worker }
445*da0073e9SAndroid Build Coastguard Worker 
setReleaseWeightsWhenPrepacking(bool e)446*da0073e9SAndroid Build Coastguard Worker void Context::setReleaseWeightsWhenPrepacking(bool e) {
447*da0073e9SAndroid Build Coastguard Worker   release_original_weights = e;
448*da0073e9SAndroid Build Coastguard Worker }
449*da0073e9SAndroid Build Coastguard Worker 
setFlushDenormal(bool on)450*da0073e9SAndroid Build Coastguard Worker bool Context::setFlushDenormal(bool on) {
451*da0073e9SAndroid Build Coastguard Worker   return at::cpu::set_flush_denormal(on);
452*da0073e9SAndroid Build Coastguard Worker }
453*da0073e9SAndroid Build Coastguard Worker 
getCPUAllocator()454*da0073e9SAndroid Build Coastguard Worker Allocator* getCPUAllocator() {
455*da0073e9SAndroid Build Coastguard Worker   return c10::GetCPUAllocator();
456*da0073e9SAndroid Build Coastguard Worker }
457*da0073e9SAndroid Build Coastguard Worker 
// Per-thread switch that NoTF32Guard sets and should_disable_tf32() reports.
//   true  -> the allow_tf32 flags are overridden and TF32 is force-disabled
//   false -> the original allow_tf32 flags are followed
thread_local bool override_allow_tf32_flag{false};
463*da0073e9SAndroid Build Coastguard Worker 
NoTF32Guard()464*da0073e9SAndroid Build Coastguard Worker NoTF32Guard::NoTF32Guard() {
465*da0073e9SAndroid Build Coastguard Worker   if (!override_allow_tf32_flag) {
466*da0073e9SAndroid Build Coastguard Worker     changed = true;
467*da0073e9SAndroid Build Coastguard Worker     override_allow_tf32_flag = true;
468*da0073e9SAndroid Build Coastguard Worker   }
469*da0073e9SAndroid Build Coastguard Worker }
470*da0073e9SAndroid Build Coastguard Worker 
~NoTF32Guard()471*da0073e9SAndroid Build Coastguard Worker NoTF32Guard::~NoTF32Guard() {
472*da0073e9SAndroid Build Coastguard Worker   if (changed) {
473*da0073e9SAndroid Build Coastguard Worker     override_allow_tf32_flag = false;
474*da0073e9SAndroid Build Coastguard Worker   }
475*da0073e9SAndroid Build Coastguard Worker }
476*da0073e9SAndroid Build Coastguard Worker 
should_disable_tf32()477*da0073e9SAndroid Build Coastguard Worker bool NoTF32Guard::should_disable_tf32() {
478*da0073e9SAndroid Build Coastguard Worker   return override_allow_tf32_flag;
479*da0073e9SAndroid Build Coastguard Worker }
480*da0073e9SAndroid Build Coastguard Worker 
// Per-thread flag, set by ROCmBackwardPassGuard, that ops can query to learn
// whether they are running in the backward pass. Ops can use it, for example,
// to select implementations with different numerical or performance
// characteristics.
// See https://pytorch.org/docs/stable/notes/numerical_accuracy.html for details.
// The explicit initializer only spells out the zero-initialization that
// thread_local variables receive anyway.
thread_local bool rocm_is_backward_pass{false};
486*da0073e9SAndroid Build Coastguard Worker 
ROCmBackwardPassGuard()487*da0073e9SAndroid Build Coastguard Worker ROCmBackwardPassGuard::ROCmBackwardPassGuard() {
488*da0073e9SAndroid Build Coastguard Worker   rocm_is_backward_pass = true;
489*da0073e9SAndroid Build Coastguard Worker }
490*da0073e9SAndroid Build Coastguard Worker 
~ROCmBackwardPassGuard()491*da0073e9SAndroid Build Coastguard Worker ROCmBackwardPassGuard::~ROCmBackwardPassGuard() {
492*da0073e9SAndroid Build Coastguard Worker   rocm_is_backward_pass = false;
493*da0073e9SAndroid Build Coastguard Worker }
494*da0073e9SAndroid Build Coastguard Worker 
is_backward_pass()495*da0073e9SAndroid Build Coastguard Worker bool ROCmBackwardPassGuard::is_backward_pass() {
496*da0073e9SAndroid Build Coastguard Worker   return rocm_is_backward_pass;
497*da0073e9SAndroid Build Coastguard Worker }
498*da0073e9SAndroid Build Coastguard Worker 
areVmapFallbackWarningsEnabled() const499*da0073e9SAndroid Build Coastguard Worker bool Context::areVmapFallbackWarningsEnabled() const {
500*da0073e9SAndroid Build Coastguard Worker   return display_vmap_fallback_warnings_;
501*da0073e9SAndroid Build Coastguard Worker }
502*da0073e9SAndroid Build Coastguard Worker 
setDisplayVmapFallbackWarnings(bool enabled)503*da0073e9SAndroid Build Coastguard Worker void Context::setDisplayVmapFallbackWarnings(bool enabled) {
504*da0073e9SAndroid Build Coastguard Worker   display_vmap_fallback_warnings_ = enabled;
505*da0073e9SAndroid Build Coastguard Worker }
506*da0073e9SAndroid Build Coastguard Worker 
setDefaultMobileCPUAllocator()507*da0073e9SAndroid Build Coastguard Worker void Context::setDefaultMobileCPUAllocator() {
508*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(prev_allocator_ptr_ == nullptr,
509*da0073e9SAndroid Build Coastguard Worker       "Already within the scope of another non-default cpu allocator."
510*da0073e9SAndroid Build Coastguard Worker       "Cannot set another allocator.");
511*da0073e9SAndroid Build Coastguard Worker   // Setting the priority high to make sure no other allocator gets used instead of this.
512*da0073e9SAndroid Build Coastguard Worker   prev_allocator_ptr_ = c10::GetCPUAllocator();
513*da0073e9SAndroid Build Coastguard Worker   c10::SetCPUAllocator(c10::GetDefaultMobileCPUAllocator(), /*priority*/ 100);
514*da0073e9SAndroid Build Coastguard Worker }
515*da0073e9SAndroid Build Coastguard Worker 
unsetDefaultMobileCPUAllocator()516*da0073e9SAndroid Build Coastguard Worker void Context::unsetDefaultMobileCPUAllocator() {
517*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(prev_allocator_ptr_ != nullptr,
518*da0073e9SAndroid Build Coastguard Worker       "setDefaultMobileCPUAllocator must have been called "
519*da0073e9SAndroid Build Coastguard Worker       "before unsetDefaultMobileCPUAllocator.");
520*da0073e9SAndroid Build Coastguard Worker   // Setting the priority high to make sure no other allocator gets used instead of this.
521*da0073e9SAndroid Build Coastguard Worker   c10::SetCPUAllocator(prev_allocator_ptr_ , /*priority*/ 100);
522*da0073e9SAndroid Build Coastguard Worker   prev_allocator_ptr_ = nullptr;
523*da0073e9SAndroid Build Coastguard Worker }
524*da0073e9SAndroid Build Coastguard Worker 
allowFP16ReductionCPU() const525*da0073e9SAndroid Build Coastguard Worker bool Context::allowFP16ReductionCPU() const {
526*da0073e9SAndroid Build Coastguard Worker   return allow_fp16_reduction_cpu;
527*da0073e9SAndroid Build Coastguard Worker }
528*da0073e9SAndroid Build Coastguard Worker 
setAllowFP16ReductionCPU(bool b)529*da0073e9SAndroid Build Coastguard Worker void Context::setAllowFP16ReductionCPU(bool b) {
530*da0073e9SAndroid Build Coastguard Worker   if ( b && !allow_fp16_reduction_cpu) {
531*da0073e9SAndroid Build Coastguard Worker     // Check that CPU supports fp16 reductions
532*da0073e9SAndroid Build Coastguard Worker #if defined(__aarch64__) && !defined(C10_MOBILE)
533*da0073e9SAndroid Build Coastguard Worker     if (!cpuinfo_initialize() || !cpuinfo_has_arm_fp16_arith())
534*da0073e9SAndroid Build Coastguard Worker #else
535*da0073e9SAndroid Build Coastguard Worker     if (true)
536*da0073e9SAndroid Build Coastguard Worker #endif
537*da0073e9SAndroid Build Coastguard Worker       throw std::runtime_error("Float16 arithmetic is not supported by the CPU!");
538*da0073e9SAndroid Build Coastguard Worker   }
539*da0073e9SAndroid Build Coastguard Worker   allow_fp16_reduction_cpu = b;
540*da0073e9SAndroid Build Coastguard Worker }
541*da0073e9SAndroid Build Coastguard Worker } // namespace at
542