1*da0073e9SAndroid Build Coastguard Worker #include <ATen/ThreadLocalState.h>
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Worker #if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER)
4*da0073e9SAndroid Build Coastguard Worker #include <ATen/autocast_mode.h>
5*da0073e9SAndroid Build Coastguard Worker #include <ATen/core/grad_mode.h>
6*da0073e9SAndroid Build Coastguard Worker #endif
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Worker #include <ATen/record_function.h>
9*da0073e9SAndroid Build Coastguard Worker #include <ATen/SavedTensorHooks.h>
10*da0073e9SAndroid Build Coastguard Worker #include <ATen/FunctionalTensorWrapper.h>
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Worker namespace at {
13*da0073e9SAndroid Build Coastguard Worker
ThreadLocalState()14*da0073e9SAndroid Build Coastguard Worker ThreadLocalState::ThreadLocalState()
15*da0073e9SAndroid Build Coastguard Worker : dispatch_key_(c10::impl::tls_local_dispatch_key_set()),
16*da0073e9SAndroid Build Coastguard Worker debug_info_(c10::ThreadLocalDebugInfo::current()),
17*da0073e9SAndroid Build Coastguard Worker rf_tls_(at::get_record_function_tls_()), functorch_tls_(functorch::getCopyOfFuncTorchTLS()),
18*da0073e9SAndroid Build Coastguard Worker autograd_tls_(c10::AutogradState::get_tls_state()),
19*da0073e9SAndroid Build Coastguard Worker torch_dispatch_mode_state_(c10::impl::TorchDispatchModeTLS::get_state()), python_dispatcher_state_(c10::impl::PythonDispatcherTLS::get_state()),
20*da0073e9SAndroid Build Coastguard Worker python_torch_function_state_(at::impl::PythonTorchFunctionTLS::get_state()),
21*da0073e9SAndroid Build Coastguard Worker saved_tensors_default_hooks_state_(at::SavedTensorDefaultHooks::get_tls_state()), functionalization_reapply_views_state_(at::functionalization::impl::getFunctionalizationReapplyViewsTLS()),
22*da0073e9SAndroid Build Coastguard Worker saved_objects_(at::impl::ThreadLocalPythonObjects::get_state()) {
23*da0073e9SAndroid Build Coastguard Worker #if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER)
24*da0073e9SAndroid Build Coastguard Worker for(uint8_t i=0; i<autocast_dtypes_.size(); i++) {
25*da0073e9SAndroid Build Coastguard Worker autocast_dtypes_[i] = at::autocast::get_autocast_dtype(static_cast<at::DeviceType>(i));
26*da0073e9SAndroid Build Coastguard Worker }
27*da0073e9SAndroid Build Coastguard Worker #endif
28*da0073e9SAndroid Build Coastguard Worker }
29*da0073e9SAndroid Build Coastguard Worker
set_grad_mode(bool enabled)30*da0073e9SAndroid Build Coastguard Worker void ThreadLocalState::set_grad_mode(bool enabled) {
31*da0073e9SAndroid Build Coastguard Worker autograd_tls_.set_grad_mode(enabled);
32*da0073e9SAndroid Build Coastguard Worker }
33*da0073e9SAndroid Build Coastguard Worker
set_multithreading_enabled(bool enabled)34*da0073e9SAndroid Build Coastguard Worker void ThreadLocalState::set_multithreading_enabled(bool enabled) {
35*da0073e9SAndroid Build Coastguard Worker autograd_tls_.set_multithreading_enabled(enabled);
36*da0073e9SAndroid Build Coastguard Worker }
37*da0073e9SAndroid Build Coastguard Worker
38*da0073e9SAndroid Build Coastguard Worker /* static */
setThreadLocalState(const ThreadLocalState & state)39*da0073e9SAndroid Build Coastguard Worker void ThreadLocalState::setThreadLocalState(
40*da0073e9SAndroid Build Coastguard Worker const ThreadLocalState& state) {
41*da0073e9SAndroid Build Coastguard Worker // Note that setting the InferenceMode TLS in this function is ONLY ok because we always
42*da0073e9SAndroid Build Coastguard Worker // restore the dispatch key set TLS at the same time.
43*da0073e9SAndroid Build Coastguard Worker c10::AutogradState::set_tls_state(state.autograd_tls_);
44*da0073e9SAndroid Build Coastguard Worker
45*da0073e9SAndroid Build Coastguard Worker c10::impl::TorchDispatchModeTLS::set_state(state.torch_dispatch_mode_state_);
46*da0073e9SAndroid Build Coastguard Worker
47*da0073e9SAndroid Build Coastguard Worker at::impl::PythonTorchFunctionTLS::set_state(state.python_torch_function_state_);
48*da0073e9SAndroid Build Coastguard Worker
49*da0073e9SAndroid Build Coastguard Worker at::set_record_function_tls_(state.rf_tls_);
50*da0073e9SAndroid Build Coastguard Worker
51*da0073e9SAndroid Build Coastguard Worker at::SavedTensorDefaultHooks::set_tls_state(state.saved_tensors_default_hooks_state_);
52*da0073e9SAndroid Build Coastguard Worker
53*da0073e9SAndroid Build Coastguard Worker c10::impl::PythonDispatcherTLS::set_state(state.python_dispatcher_state_);
54*da0073e9SAndroid Build Coastguard Worker
55*da0073e9SAndroid Build Coastguard Worker c10::ThreadLocalDebugInfo::_forceCurrentDebugInfo(state.debug_info_);
56*da0073e9SAndroid Build Coastguard Worker
57*da0073e9SAndroid Build Coastguard Worker c10::impl::_force_tls_local_dispatch_key_set(state.dispatch_key_);
58*da0073e9SAndroid Build Coastguard Worker
59*da0073e9SAndroid Build Coastguard Worker functorch::setFuncTorchTLS(state.functorch_tls_);
60*da0073e9SAndroid Build Coastguard Worker
61*da0073e9SAndroid Build Coastguard Worker at::functionalization::impl::setFunctionalizationReapplyViewsTLS(state.functionalization_reapply_views_state_);
62*da0073e9SAndroid Build Coastguard Worker
63*da0073e9SAndroid Build Coastguard Worker at::impl::ThreadLocalPythonObjects::set_state(state.saved_objects_);
64*da0073e9SAndroid Build Coastguard Worker #if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER)
65*da0073e9SAndroid Build Coastguard Worker for(uint8_t i=0; i<state.autocast_dtypes_.size(); i++) {
66*da0073e9SAndroid Build Coastguard Worker at::autocast::set_autocast_dtype(static_cast<at::DeviceType>(i), state.autocast_dtypes_[i]);
67*da0073e9SAndroid Build Coastguard Worker }
68*da0073e9SAndroid Build Coastguard Worker #endif
69*da0073e9SAndroid Build Coastguard Worker }
70*da0073e9SAndroid Build Coastguard Worker
71*da0073e9SAndroid Build Coastguard Worker } // namespace at
72