#include <ATen/ThreadLocalState.h>

#include <cstddef>

#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER)
#include <ATen/autocast_mode.h>
#include <ATen/core/grad_mode.h>
#endif

#include <ATen/record_function.h>
#include <ATen/SavedTensorHooks.h>
#include <ATen/FunctionalTensorWrapper.h>
11
12 namespace at {
13
ThreadLocalState()14 ThreadLocalState::ThreadLocalState()
15 : dispatch_key_(c10::impl::tls_local_dispatch_key_set()),
16 debug_info_(c10::ThreadLocalDebugInfo::current()),
17 rf_tls_(at::get_record_function_tls_()), functorch_tls_(functorch::getCopyOfFuncTorchTLS()),
18 autograd_tls_(c10::AutogradState::get_tls_state()),
19 torch_dispatch_mode_state_(c10::impl::TorchDispatchModeTLS::get_state()), python_dispatcher_state_(c10::impl::PythonDispatcherTLS::get_state()),
20 python_torch_function_state_(at::impl::PythonTorchFunctionTLS::get_state()),
21 saved_tensors_default_hooks_state_(at::SavedTensorDefaultHooks::get_tls_state()), functionalization_reapply_views_state_(at::functionalization::impl::getFunctionalizationReapplyViewsTLS()),
22 saved_objects_(at::impl::ThreadLocalPythonObjects::get_state()) {
23 #if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER)
24 for(uint8_t i=0; i<autocast_dtypes_.size(); i++) {
25 autocast_dtypes_[i] = at::autocast::get_autocast_dtype(static_cast<at::DeviceType>(i));
26 }
27 #endif
28 }
29
set_grad_mode(bool enabled)30 void ThreadLocalState::set_grad_mode(bool enabled) {
31 autograd_tls_.set_grad_mode(enabled);
32 }
33
set_multithreading_enabled(bool enabled)34 void ThreadLocalState::set_multithreading_enabled(bool enabled) {
35 autograd_tls_.set_multithreading_enabled(enabled);
36 }
37
38 /* static */
setThreadLocalState(const ThreadLocalState & state)39 void ThreadLocalState::setThreadLocalState(
40 const ThreadLocalState& state) {
41 // Note that setting the InferenceMode TLS in this function is ONLY ok because we always
42 // restore the dispatch key set TLS at the same time.
43 c10::AutogradState::set_tls_state(state.autograd_tls_);
44
45 c10::impl::TorchDispatchModeTLS::set_state(state.torch_dispatch_mode_state_);
46
47 at::impl::PythonTorchFunctionTLS::set_state(state.python_torch_function_state_);
48
49 at::set_record_function_tls_(state.rf_tls_);
50
51 at::SavedTensorDefaultHooks::set_tls_state(state.saved_tensors_default_hooks_state_);
52
53 c10::impl::PythonDispatcherTLS::set_state(state.python_dispatcher_state_);
54
55 c10::ThreadLocalDebugInfo::_forceCurrentDebugInfo(state.debug_info_);
56
57 c10::impl::_force_tls_local_dispatch_key_set(state.dispatch_key_);
58
59 functorch::setFuncTorchTLS(state.functorch_tls_);
60
61 at::functionalization::impl::setFunctionalizationReapplyViewsTLS(state.functionalization_reapply_views_state_);
62
63 at::impl::ThreadLocalPythonObjects::set_state(state.saved_objects_);
64 #if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER)
65 for(uint8_t i=0; i<state.autocast_dtypes_.size(); i++) {
66 at::autocast::set_autocast_dtype(static_cast<at::DeviceType>(i), state.autocast_dtypes_[i]);
67 }
68 #endif
69 }
70
71 } // namespace at
72