#include namespace at::functorch { namespace { thread_local std::unique_ptr kFuncTorchTLS = nullptr; } std::unique_ptr getCopyOfFuncTorchTLS() { if (kFuncTorchTLS == nullptr) { return nullptr; } return kFuncTorchTLS->deepcopy(); } void setFuncTorchTLS(const std::shared_ptr& state) { if (state == nullptr) { kFuncTorchTLS = nullptr; return; } kFuncTorchTLS = state->deepcopy(); } std::unique_ptr& functorchTLSAccessor() { return kFuncTorchTLS; } } // namespace at::functorch