1 #include <ATen/FuncTorchTLS.h> 2 3 namespace at::functorch { 4 5 namespace { 6 7 thread_local std::unique_ptr<FuncTorchTLSBase> kFuncTorchTLS = nullptr; 8 9 } 10 getCopyOfFuncTorchTLS()11std::unique_ptr<FuncTorchTLSBase> getCopyOfFuncTorchTLS() { 12 if (kFuncTorchTLS == nullptr) { 13 return nullptr; 14 } 15 return kFuncTorchTLS->deepcopy(); 16 } 17 setFuncTorchTLS(const std::shared_ptr<const FuncTorchTLSBase> & state)18void setFuncTorchTLS(const std::shared_ptr<const FuncTorchTLSBase>& state) { 19 if (state == nullptr) { 20 kFuncTorchTLS = nullptr; 21 return; 22 } 23 kFuncTorchTLS = state->deepcopy(); 24 } 25 functorchTLSAccessor()26std::unique_ptr<FuncTorchTLSBase>& functorchTLSAccessor() { 27 return kFuncTorchTLS; 28 } 29 30 31 } // namespace at::functorch 32