1 #include <ATen/SavedTensorHooks.h> 2 #include <c10/util/Exception.h> 3 #include <stack> 4 #include <utility> 5 #include <c10/core/SafePyObject.h> 6 7 namespace at { 8 9 namespace { 10 thread_local impl::SavedTensorDefaultHooksTLS tls; 11 12 // This flag is set to true the first time default hooks are registered 13 // and left at true for the rest of the execution. 14 // It's an optimization so that users who never use default hooks don't need to 15 // read the thread_local variables pack_hook_ and unpack_hook_. 16 static bool is_initialized(false); 17 } 18 assertSavedTensorHooksNotDisabled()19static void assertSavedTensorHooksNotDisabled() { 20 TORCH_CHECK(SavedTensorDefaultHooks::is_enabled(), tls.disabled_error_message.value()); 21 } 22 is_enabled()23bool SavedTensorDefaultHooks::is_enabled() { 24 // See NOTE: [disabled_error_message invariant] 25 return !tls.disabled_error_message.has_value(); 26 } 27 disable(const std::string & message)28void SavedTensorDefaultHooks::disable(const std::string& message) { 29 tls.disabled_error_message = message; 30 if (!tls.stack.empty()) { 31 assertSavedTensorHooksNotDisabled(); 32 } 33 } 34 enable()35void SavedTensorDefaultHooks::enable() { 36 tls.disabled_error_message = std::nullopt; 37 } 38 set_tracing(bool is_tracing)39/* static */ bool SavedTensorDefaultHooks::set_tracing(bool is_tracing) { 40 bool prior = tls.is_tracing; 41 tls.is_tracing = is_tracing; 42 return prior; 43 } 44 get_disabled_error_message()45const std::optional<std::string>& SavedTensorDefaultHooks::get_disabled_error_message() { 46 return tls.disabled_error_message; 47 } 48 get_tls_state()49const impl::SavedTensorDefaultHooksTLS& SavedTensorDefaultHooks::get_tls_state() { 50 return tls; 51 } 52 set_tls_state(const impl::SavedTensorDefaultHooksTLS & state)53void SavedTensorDefaultHooks::set_tls_state(const impl::SavedTensorDefaultHooksTLS& state) { 54 tls = state; 55 } 56 lazy_initialize()57void SavedTensorDefaultHooks::lazy_initialize() { 58 is_initialized = true; 59 } 60 push_hooks(SafePyObject pack_hook,SafePyObject unpack_hook)61void SavedTensorDefaultHooks::push_hooks(SafePyObject pack_hook, SafePyObject unpack_hook) { 62 TORCH_INTERNAL_ASSERT(is_initialized); 63 assertSavedTensorHooksNotDisabled(); 64 tls.stack.emplace(std::move(pack_hook), std::move(unpack_hook)); 65 } 66 pop_hooks()67std::pair<SafePyObject, SafePyObject> SavedTensorDefaultHooks::pop_hooks() { 68 TORCH_INTERNAL_ASSERT(is_initialized && !tls.stack.empty()); 69 std::pair<SafePyObject, SafePyObject> hooks = std::move(tls.stack.top()); 70 tls.stack.pop(); 71 return hooks; 72 } 73 get_hooks()74std::optional<std::pair<SafePyObject, SafePyObject>> SavedTensorDefaultHooks::get_hooks() { 75 // For tls.is_tracing, see NOTE: [Deferring tensor pack/unpack hooks until runtime] 76 if (!is_initialized || tls.stack.empty() || tls.is_tracing) { 77 return c10::nullopt; 78 } 79 return tls.stack.top(); 80 } 81 82 } 83