xref: /aosp_15_r20/external/pytorch/torch/csrc/utils/python_torch_function_mode.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/PythonTorchFunctionTLS.h>
4 
5 namespace torch::overrides {
6 
7 struct StashTorchFunctionModeGuard {
StashTorchFunctionModeGuardStashTorchFunctionModeGuard8   StashTorchFunctionModeGuard() {
9     cur_mode_ = at::impl::PythonTorchFunctionTLS::pop_stack();
10   }
~StashTorchFunctionModeGuardStashTorchFunctionModeGuard11   ~StashTorchFunctionModeGuard() {
12     at::impl::PythonTorchFunctionTLS::push_onto_stack(cur_mode_);
13   }
14 
get_cur_modeStashTorchFunctionModeGuard15   const std::shared_ptr<c10::SafePyObject>& get_cur_mode() {
16     return cur_mode_;
17   }
18 
19  private:
20   std::shared_ptr<c10::SafePyObject> cur_mode_;
21 };
22 
23 } // namespace torch::overrides
24