xref: /aosp_15_r20/external/pytorch/aten/src/ATen/ThreadLocalState.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <ATen/ThreadLocalState.h>
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER)
4*da0073e9SAndroid Build Coastguard Worker #include <ATen/autocast_mode.h>
5*da0073e9SAndroid Build Coastguard Worker #include <ATen/core/grad_mode.h>
6*da0073e9SAndroid Build Coastguard Worker #endif
7*da0073e9SAndroid Build Coastguard Worker 
8*da0073e9SAndroid Build Coastguard Worker #include <ATen/record_function.h>
9*da0073e9SAndroid Build Coastguard Worker #include <ATen/SavedTensorHooks.h>
10*da0073e9SAndroid Build Coastguard Worker #include <ATen/FunctionalTensorWrapper.h>
11*da0073e9SAndroid Build Coastguard Worker 
12*da0073e9SAndroid Build Coastguard Worker namespace at {
13*da0073e9SAndroid Build Coastguard Worker 
ThreadLocalState()14*da0073e9SAndroid Build Coastguard Worker ThreadLocalState::ThreadLocalState()
15*da0073e9SAndroid Build Coastguard Worker     : dispatch_key_(c10::impl::tls_local_dispatch_key_set()),
16*da0073e9SAndroid Build Coastguard Worker       debug_info_(c10::ThreadLocalDebugInfo::current()),
17*da0073e9SAndroid Build Coastguard Worker       rf_tls_(at::get_record_function_tls_()), functorch_tls_(functorch::getCopyOfFuncTorchTLS()),
18*da0073e9SAndroid Build Coastguard Worker       autograd_tls_(c10::AutogradState::get_tls_state()),
19*da0073e9SAndroid Build Coastguard Worker       torch_dispatch_mode_state_(c10::impl::TorchDispatchModeTLS::get_state()), python_dispatcher_state_(c10::impl::PythonDispatcherTLS::get_state()),
20*da0073e9SAndroid Build Coastguard Worker       python_torch_function_state_(at::impl::PythonTorchFunctionTLS::get_state()),
21*da0073e9SAndroid Build Coastguard Worker       saved_tensors_default_hooks_state_(at::SavedTensorDefaultHooks::get_tls_state()), functionalization_reapply_views_state_(at::functionalization::impl::getFunctionalizationReapplyViewsTLS()),
22*da0073e9SAndroid Build Coastguard Worker       saved_objects_(at::impl::ThreadLocalPythonObjects::get_state()) {
23*da0073e9SAndroid Build Coastguard Worker #if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER)
24*da0073e9SAndroid Build Coastguard Worker   for(uint8_t i=0; i<autocast_dtypes_.size(); i++) {
25*da0073e9SAndroid Build Coastguard Worker      autocast_dtypes_[i] = at::autocast::get_autocast_dtype(static_cast<at::DeviceType>(i));
26*da0073e9SAndroid Build Coastguard Worker   }
27*da0073e9SAndroid Build Coastguard Worker #endif
28*da0073e9SAndroid Build Coastguard Worker }
29*da0073e9SAndroid Build Coastguard Worker 
set_grad_mode(bool enabled)30*da0073e9SAndroid Build Coastguard Worker void ThreadLocalState::set_grad_mode(bool enabled) {
31*da0073e9SAndroid Build Coastguard Worker   autograd_tls_.set_grad_mode(enabled);
32*da0073e9SAndroid Build Coastguard Worker }
33*da0073e9SAndroid Build Coastguard Worker 
set_multithreading_enabled(bool enabled)34*da0073e9SAndroid Build Coastguard Worker void ThreadLocalState::set_multithreading_enabled(bool enabled) {
35*da0073e9SAndroid Build Coastguard Worker   autograd_tls_.set_multithreading_enabled(enabled);
36*da0073e9SAndroid Build Coastguard Worker }
37*da0073e9SAndroid Build Coastguard Worker 
38*da0073e9SAndroid Build Coastguard Worker /* static */
setThreadLocalState(const ThreadLocalState & state)39*da0073e9SAndroid Build Coastguard Worker void ThreadLocalState::setThreadLocalState(
40*da0073e9SAndroid Build Coastguard Worker     const ThreadLocalState& state) {
41*da0073e9SAndroid Build Coastguard Worker   // Note that setting the InferenceMode TLS in this function is ONLY ok because we always
42*da0073e9SAndroid Build Coastguard Worker   // restore the dispatch key set TLS at the same time.
43*da0073e9SAndroid Build Coastguard Worker   c10::AutogradState::set_tls_state(state.autograd_tls_);
44*da0073e9SAndroid Build Coastguard Worker 
45*da0073e9SAndroid Build Coastguard Worker   c10::impl::TorchDispatchModeTLS::set_state(state.torch_dispatch_mode_state_);
46*da0073e9SAndroid Build Coastguard Worker 
47*da0073e9SAndroid Build Coastguard Worker   at::impl::PythonTorchFunctionTLS::set_state(state.python_torch_function_state_);
48*da0073e9SAndroid Build Coastguard Worker 
49*da0073e9SAndroid Build Coastguard Worker   at::set_record_function_tls_(state.rf_tls_);
50*da0073e9SAndroid Build Coastguard Worker 
51*da0073e9SAndroid Build Coastguard Worker   at::SavedTensorDefaultHooks::set_tls_state(state.saved_tensors_default_hooks_state_);
52*da0073e9SAndroid Build Coastguard Worker 
53*da0073e9SAndroid Build Coastguard Worker   c10::impl::PythonDispatcherTLS::set_state(state.python_dispatcher_state_);
54*da0073e9SAndroid Build Coastguard Worker 
55*da0073e9SAndroid Build Coastguard Worker   c10::ThreadLocalDebugInfo::_forceCurrentDebugInfo(state.debug_info_);
56*da0073e9SAndroid Build Coastguard Worker 
57*da0073e9SAndroid Build Coastguard Worker   c10::impl::_force_tls_local_dispatch_key_set(state.dispatch_key_);
58*da0073e9SAndroid Build Coastguard Worker 
59*da0073e9SAndroid Build Coastguard Worker   functorch::setFuncTorchTLS(state.functorch_tls_);
60*da0073e9SAndroid Build Coastguard Worker 
61*da0073e9SAndroid Build Coastguard Worker   at::functionalization::impl::setFunctionalizationReapplyViewsTLS(state.functionalization_reapply_views_state_);
62*da0073e9SAndroid Build Coastguard Worker 
63*da0073e9SAndroid Build Coastguard Worker   at::impl::ThreadLocalPythonObjects::set_state(state.saved_objects_);
64*da0073e9SAndroid Build Coastguard Worker #if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER)
65*da0073e9SAndroid Build Coastguard Worker   for(uint8_t i=0; i<state.autocast_dtypes_.size(); i++) {
66*da0073e9SAndroid Build Coastguard Worker      at::autocast::set_autocast_dtype(static_cast<at::DeviceType>(i), state.autocast_dtypes_[i]);
67*da0073e9SAndroid Build Coastguard Worker   }
68*da0073e9SAndroid Build Coastguard Worker #endif
69*da0073e9SAndroid Build Coastguard Worker }
70*da0073e9SAndroid Build Coastguard Worker 
71*da0073e9SAndroid Build Coastguard Worker } // namespace at
72