xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/PythonFallbackKernel.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/core/TorchDispatchUtils.h>
3 
4 
5 namespace at::impl {
6 
7 struct TORCH_API RestorePythonTLSSnapshot {
8   RestorePythonTLSSnapshot();
9   ~RestorePythonTLSSnapshot();
10 
11 private:
12   c10::impl::LocalDispatchKeySet saved_;
13   c10::impl::ForceDispatchKeyGuard guard_;
14 };
15 
16 
17 // RAII guard to make working with the above TLS safer.
18 struct TORCH_API MaybeSetTLSOnEntryGuard {
19 public:
20   MaybeSetTLSOnEntryGuard();
21   ~MaybeSetTLSOnEntryGuard();
22 
23 private:
24   bool value_set_;
25 };
26 
27 } // namespace at::impl
28