xref: /aosp_15_r20/external/pytorch/aten/src/ATen/ThreadLocalState.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/ThreadLocalState.h>

#include <cstddef>

#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER)
#include <ATen/autocast_mode.h>
#include <ATen/core/grad_mode.h>
#endif

#include <ATen/record_function.h>
#include <ATen/SavedTensorHooks.h>
#include <ATen/FunctionalTensorWrapper.h>
11 
12 namespace at {
13 
ThreadLocalState()14 ThreadLocalState::ThreadLocalState()
15     : dispatch_key_(c10::impl::tls_local_dispatch_key_set()),
16       debug_info_(c10::ThreadLocalDebugInfo::current()),
17       rf_tls_(at::get_record_function_tls_()), functorch_tls_(functorch::getCopyOfFuncTorchTLS()),
18       autograd_tls_(c10::AutogradState::get_tls_state()),
19       torch_dispatch_mode_state_(c10::impl::TorchDispatchModeTLS::get_state()), python_dispatcher_state_(c10::impl::PythonDispatcherTLS::get_state()),
20       python_torch_function_state_(at::impl::PythonTorchFunctionTLS::get_state()),
21       saved_tensors_default_hooks_state_(at::SavedTensorDefaultHooks::get_tls_state()), functionalization_reapply_views_state_(at::functionalization::impl::getFunctionalizationReapplyViewsTLS()),
22       saved_objects_(at::impl::ThreadLocalPythonObjects::get_state()) {
23 #if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER)
24   for(uint8_t i=0; i<autocast_dtypes_.size(); i++) {
25      autocast_dtypes_[i] = at::autocast::get_autocast_dtype(static_cast<at::DeviceType>(i));
26   }
27 #endif
28 }
29 
set_grad_mode(bool enabled)30 void ThreadLocalState::set_grad_mode(bool enabled) {
31   autograd_tls_.set_grad_mode(enabled);
32 }
33 
set_multithreading_enabled(bool enabled)34 void ThreadLocalState::set_multithreading_enabled(bool enabled) {
35   autograd_tls_.set_multithreading_enabled(enabled);
36 }
37 
38 /* static */
setThreadLocalState(const ThreadLocalState & state)39 void ThreadLocalState::setThreadLocalState(
40     const ThreadLocalState& state) {
41   // Note that setting the InferenceMode TLS in this function is ONLY ok because we always
42   // restore the dispatch key set TLS at the same time.
43   c10::AutogradState::set_tls_state(state.autograd_tls_);
44 
45   c10::impl::TorchDispatchModeTLS::set_state(state.torch_dispatch_mode_state_);
46 
47   at::impl::PythonTorchFunctionTLS::set_state(state.python_torch_function_state_);
48 
49   at::set_record_function_tls_(state.rf_tls_);
50 
51   at::SavedTensorDefaultHooks::set_tls_state(state.saved_tensors_default_hooks_state_);
52 
53   c10::impl::PythonDispatcherTLS::set_state(state.python_dispatcher_state_);
54 
55   c10::ThreadLocalDebugInfo::_forceCurrentDebugInfo(state.debug_info_);
56 
57   c10::impl::_force_tls_local_dispatch_key_set(state.dispatch_key_);
58 
59   functorch::setFuncTorchTLS(state.functorch_tls_);
60 
61   at::functionalization::impl::setFunctionalizationReapplyViewsTLS(state.functionalization_reapply_views_state_);
62 
63   at::impl::ThreadLocalPythonObjects::set_state(state.saved_objects_);
64 #if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER)
65   for(uint8_t i=0; i<state.autocast_dtypes_.size(); i++) {
66      at::autocast::set_autocast_dtype(static_cast<at::DeviceType>(i), state.autocast_dtypes_[i]);
67   }
68 #endif
69 }
70 
71 } // namespace at
72