1 #pragma once 2 3 #include <c10/core/AutogradState.h> 4 #include <c10/macros/Export.h> 5 6 namespace c10 { 7 8 struct C10_API GradMode { 9 static bool is_enabled(); 10 static void set_enabled(bool enabled); 11 }; 12 13 // A RAII, thread local (!) guard that enables or disables grad mode upon 14 // construction, and sets it back to the original value upon destruction. 15 struct C10_API AutoGradMode { AutoGradModeAutoGradMode16 AutoGradMode(bool enabled) : prev_mode(GradMode::is_enabled()) { 17 GradMode::set_enabled(enabled); 18 } ~AutoGradModeAutoGradMode19 ~AutoGradMode() { 20 GradMode::set_enabled(prev_mode); 21 } 22 bool prev_mode; 23 }; 24 25 // A RAII, thread local (!) guard that stops future operations from building 26 // gradients. 27 struct C10_API NoGradGuard : public AutoGradMode { NoGradGuardNoGradGuard28 NoGradGuard() : AutoGradMode(/*enabled=*/false) {} 29 }; 30 31 // A RAII, thread local (!) guard that enables or disables forward grad mode 32 // upon construction, and sets it back to the original value upon destruction. 33 struct C10_API AutoFwGradMode { AutoFwGradModeAutoFwGradMode34 AutoFwGradMode(bool enabled) 35 : prev_mode(AutogradState::get_tls_state().get_fw_grad_mode()) { 36 AutogradState::get_tls_state().set_fw_grad_mode(enabled); 37 } ~AutoFwGradModeAutoFwGradMode38 ~AutoFwGradMode() { 39 AutogradState::get_tls_state().set_fw_grad_mode(prev_mode); 40 } 41 bool prev_mode; 42 }; 43 44 } // namespace c10 45