xref: /aosp_15_r20/external/pytorch/aten/src/ATen/FuncTorchTLS.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/FuncTorchTLS.h>
2 
3 namespace at::functorch {
4 
5 namespace {
6 
7 thread_local std::unique_ptr<FuncTorchTLSBase> kFuncTorchTLS = nullptr;
8 
9 }
10 
getCopyOfFuncTorchTLS()11 std::unique_ptr<FuncTorchTLSBase> getCopyOfFuncTorchTLS() {
12   if (kFuncTorchTLS == nullptr) {
13     return nullptr;
14   }
15   return kFuncTorchTLS->deepcopy();
16 }
17 
setFuncTorchTLS(const std::shared_ptr<const FuncTorchTLSBase> & state)18 void setFuncTorchTLS(const std::shared_ptr<const FuncTorchTLSBase>& state) {
19   if (state == nullptr) {
20     kFuncTorchTLS = nullptr;
21     return;
22   }
23   kFuncTorchTLS = state->deepcopy();
24 }
25 
functorchTLSAccessor()26 std::unique_ptr<FuncTorchTLSBase>& functorchTLSAccessor() {
27   return kFuncTorchTLS;
28 }
29 
30 
31 } // namespace at::functorch
32