xref: /aosp_15_r20/external/pytorch/aten/src/ATen/SavedTensorHooks.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/SavedTensorHooks.h>
#include <c10/core/SafePyObject.h>
#include <c10/util/Exception.h>

#include <atomic>
#include <stack>
#include <utility>
6 
7 namespace at {
8 
9 namespace {
10   thread_local impl::SavedTensorDefaultHooksTLS tls;
11 
12   // This flag is set to true the first time default hooks are registered
13   // and left at true for the rest of the execution.
14   // It's an optimization so that users who never use default hooks don't need to
15   // read the thread_local variables pack_hook_ and unpack_hook_.
16   static bool is_initialized(false);
17 }
18 
assertSavedTensorHooksNotDisabled()19 static void assertSavedTensorHooksNotDisabled() {
20   TORCH_CHECK(SavedTensorDefaultHooks::is_enabled(), tls.disabled_error_message.value());
21 }
22 
is_enabled()23 bool SavedTensorDefaultHooks::is_enabled() {
24   // See NOTE: [disabled_error_message invariant]
25   return !tls.disabled_error_message.has_value();
26 }
27 
disable(const std::string & message)28 void SavedTensorDefaultHooks::disable(const std::string& message) {
29   tls.disabled_error_message = message;
30   if (!tls.stack.empty()) {
31     assertSavedTensorHooksNotDisabled();
32   }
33 }
34 
enable()35 void SavedTensorDefaultHooks::enable() {
36   tls.disabled_error_message = std::nullopt;
37 }
38 
set_tracing(bool is_tracing)39 /* static */ bool SavedTensorDefaultHooks::set_tracing(bool is_tracing) {
40   bool prior  = tls.is_tracing;
41   tls.is_tracing = is_tracing;
42   return prior;
43 }
44 
get_disabled_error_message()45 const std::optional<std::string>& SavedTensorDefaultHooks::get_disabled_error_message() {
46   return tls.disabled_error_message;
47 }
48 
get_tls_state()49 const impl::SavedTensorDefaultHooksTLS& SavedTensorDefaultHooks::get_tls_state() {
50   return tls;
51 }
52 
set_tls_state(const impl::SavedTensorDefaultHooksTLS & state)53 void SavedTensorDefaultHooks::set_tls_state(const impl::SavedTensorDefaultHooksTLS& state) {
54   tls = state;
55 }
56 
lazy_initialize()57 void SavedTensorDefaultHooks::lazy_initialize() {
58   is_initialized = true;
59 }
60 
push_hooks(SafePyObject pack_hook,SafePyObject unpack_hook)61 void SavedTensorDefaultHooks::push_hooks(SafePyObject pack_hook, SafePyObject unpack_hook) {
62   TORCH_INTERNAL_ASSERT(is_initialized);
63   assertSavedTensorHooksNotDisabled();
64   tls.stack.emplace(std::move(pack_hook), std::move(unpack_hook));
65 }
66 
pop_hooks()67 std::pair<SafePyObject, SafePyObject> SavedTensorDefaultHooks::pop_hooks() {
68   TORCH_INTERNAL_ASSERT(is_initialized && !tls.stack.empty());
69   std::pair<SafePyObject, SafePyObject> hooks = std::move(tls.stack.top());
70   tls.stack.pop();
71   return hooks;
72 }
73 
get_hooks()74 std::optional<std::pair<SafePyObject, SafePyObject>> SavedTensorDefaultHooks::get_hooks() {
75   // For tls.is_tracing, see NOTE: [Deferring tensor pack/unpack hooks until runtime]
76   if (!is_initialized || tls.stack.empty() || tls.is_tracing) {
77     return c10::nullopt;
78   }
79   return tls.stack.top();
80 }
81 
82 }
83