xref: /aosp_15_r20/external/pytorch/aten/src/ATen/PythonTorchFunctionTLS.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/SafePyObject.h>
4 #include <c10/macros/Macros.h>
5 
6 namespace at::impl {
7 
// Value stored in PythonTorchFunctionTLS describing how much of torch
// function handling is currently turned off.
// NOTE(review): the per-enumerator meanings are not spelled out in this
// header; going by the names, the states range from "nothing disabled" to
// "everything disabled" -- confirm against the code that consumes them.
enum TorchFunctionDisabledState {
  ENABLED,
  SUBCLASSES_DISABLED,
  ALL_DISABLED,
};
9 
10 struct TORCH_API PythonTorchFunctionTLS {
11   static void set_disabled_state(TorchFunctionDisabledState disabled_state_);
12   static TorchFunctionDisabledState get_disabled_state();
13 
14   static void push_onto_stack(std::shared_ptr<SafePyObject> mode);
15   static const std::shared_ptr<SafePyObject> pop_stack();
16   static const std::shared_ptr<SafePyObject>& get_stack_at(int64_t idx);
17   static int64_t stack_len();
18 
19   static const PythonTorchFunctionTLS& get_state();
20   static void set_state(const PythonTorchFunctionTLS& state);
21 
22  private:
23   // The mode TLS is split into
24   //   - disabled_state, which says which part of torch function are disabled
25   //   - stack_, which is a vector of modes representing the stack of user
26   //   defined modes
27   TorchFunctionDisabledState disabled_state_ =
28       TorchFunctionDisabledState::ENABLED;
29   std::vector<std::shared_ptr<c10::SafePyObject>> stack_;
30 };
31 
// Query returning a bool; takes no arguments.
// NOTE(review): going by the name, this reports whether a torch function
// mode is currently in effect. The exact condition lives in the
// implementation file and is not visible here -- confirm before relying on it.
TORCH_API bool torch_function_mode_enabled();
33 
// Query returning a bool; takes no arguments.
// NOTE(review): going by the name, this reports whether the disabled state
// is TorchFunctionDisabledState::ALL_DISABLED. The implementation is not
// visible in this header -- confirm before relying on it.
TORCH_API bool torch_function_all_disabled();
35 
36 } // namespace at::impl
37