1 #pragma once 2 3 #include <c10/core/SafePyObject.h> 4 #include <c10/macros/Macros.h> 5 6 namespace at::impl { 7 8 enum TorchFunctionDisabledState { ENABLED, SUBCLASSES_DISABLED, ALL_DISABLED }; 9 10 struct TORCH_API PythonTorchFunctionTLS { 11 static void set_disabled_state(TorchFunctionDisabledState disabled_state_); 12 static TorchFunctionDisabledState get_disabled_state(); 13 14 static void push_onto_stack(std::shared_ptr<SafePyObject> mode); 15 static const std::shared_ptr<SafePyObject> pop_stack(); 16 static const std::shared_ptr<SafePyObject>& get_stack_at(int64_t idx); 17 static int64_t stack_len(); 18 19 static const PythonTorchFunctionTLS& get_state(); 20 static void set_state(const PythonTorchFunctionTLS& state); 21 22 private: 23 // The mode TLS is split into 24 // - disabled_state, which says which part of torch function are disabled 25 // - stack_, which is a vector of modes representing the stack of user 26 // defined modes 27 TorchFunctionDisabledState disabled_state_ = 28 TorchFunctionDisabledState::ENABLED; 29 std::vector<std::shared_ptr<c10::SafePyObject>> stack_; 30 }; 31 32 TORCH_API bool torch_function_mode_enabled(); 33 34 TORCH_API bool torch_function_all_disabled(); 35 36 } // namespace at::impl 37