1 #pragma once 2 3 #include <ATen/PythonTorchFunctionTLS.h> 4 5 namespace torch::overrides { 6 7 struct StashTorchFunctionModeGuard { StashTorchFunctionModeGuardStashTorchFunctionModeGuard8 StashTorchFunctionModeGuard() { 9 cur_mode_ = at::impl::PythonTorchFunctionTLS::pop_stack(); 10 } ~StashTorchFunctionModeGuardStashTorchFunctionModeGuard11 ~StashTorchFunctionModeGuard() { 12 at::impl::PythonTorchFunctionTLS::push_onto_stack(cur_mode_); 13 } 14 get_cur_modeStashTorchFunctionModeGuard15 const std::shared_ptr<c10::SafePyObject>& get_cur_mode() { 16 return cur_mode_; 17 } 18 19 private: 20 std::shared_ptr<c10::SafePyObject> cur_mode_; 21 }; 22 23 } // namespace torch::overrides 24