xref: /aosp_15_r20/external/pytorch/aten/src/ATen/FuncTorchTLS.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <c10/macros/Macros.h>
#include <cstdint>
#include <memory>
5 
6 namespace at::functorch {
7 
8 // NOTE [functorch TLS in pytorch/pytorch]
9 //
10 // functorch lives out-of-tree. However, it has some TLS that needs to be
11 // propagated. The solution for that is we store a pointer to the TLS
12 // inside pytorch/pytorch and extend FuncTorchTLSBase inside functorch to
13 // include whatever functorch needs.
14 //
// We need to store a pointer due to the indirection:
// inside functorch, we will create a subclass of FuncTorchTLSBase called
// FuncTorchTLSImpl that actually contains metadata, like the DynamicLayerStack.
// FuncTorchTLSBase itself doesn't hold any metadata, because that metadata
// hasn't been defined yet at this level.
20 //
// Here in pytorch/pytorch, we will pass around FuncTorchTLSBase*, but inside
// functorch, we will assign a FuncTorchTLSImpl* to the FuncTorchTLSBase*.
// We can't directly pass around FuncTorchTLSBase (without a pointer) because
// FuncTorchTLSImpl does not fit inside a FuncTorchTLSBase: it has more
// members than the base class does.
struct TORCH_API FuncTorchTLSBase {
  // Abstract interface for the functorch TLS as seen from pytorch/pytorch.
  // It declares no data members; every operation below is pure virtual and
  // must be provided by a subclass.

  // Virtual so that an object owned through a FuncTorchTLSBase pointer
  // (e.g. std::unique_ptr<FuncTorchTLSBase>) is destroyed via the subclass.
  virtual ~FuncTorchTLSBase() = default;

  // Returns a separately owned copy of this object, typed as the base class.
  virtual std::unique_ptr<FuncTorchTLSBase> deepcopy() const = 0;

  // Hooks for checking whether specific autograd features are supported.
  // NOTE(review): how "unsupported" is reported (presumably by raising an
  // error) is decided by the subclass and is not visible here -- confirm.
  // NOTE(review): the meaning of the int64_t returned by the first hook is
  // not visible from this header -- confirm against the implementation.
  virtual int64_t checkSupportsSingleLevelAutogradFunction() const = 0;
  virtual void checkSupportsCppAutogradFunction() const = 0;
  virtual void checkSupportsInplaceRequiresGrad() const = 0;
  virtual void checkSupportsRetainGrad() const = 0;
};
35 
// Returns a deep copy of the functorch TLS. The copy is owned by the caller.
// NOTE(review): the result when no TLS has been set is not visible from this
// header (a null pointer seems plausible) -- confirm in the implementation.
TORCH_API std::unique_ptr<FuncTorchTLSBase> getCopyOfFuncTorchTLS();
38 
// Sets the functorch TLS from `state`. Always does a deep copy of `state`
// rather than storing the pointer that was passed in.
// NOTE(review): handling of a null `state` is not visible from this header
// -- confirm in the implementation.
TORCH_API void setFuncTorchTLS(
    const std::shared_ptr<const FuncTorchTLSBase>& state);
42 
// Gets a mutable reference to the functorch TLS, i.e. to the unique_ptr that
// holds it, so the stored state can be inspected or replaced through it.
// NOTE(review): presumably refers to per-thread storage, given "TLS" --
// confirm in the implementation.
TORCH_API std::unique_ptr<FuncTorchTLSBase>& functorchTLSAccessor();
45 
46 } // namespace at::functorch
47