1 #pragma once 2 3 #include <c10/macros/Macros.h> 4 #include <memory> 5 6 namespace at::functorch { 7 8 // NOTE [functorch TLS in pytorch/pytorch] 9 // 10 // functorch lives out-of-tree. However, it has some TLS that needs to be 11 // propagated. The solution for that is we store a pointer to the TLS 12 // inside pytorch/pytorch and extend FuncTorchTLSBase inside functorch to 13 // include whatever functorch needs. 14 // 15 // We need to store a pointer due to the indirection: 16 // inside functorch, we will create a subclass of FunctorchTLSBase called 17 // FuncTorchTLSImpl that actually contains metadata, like the DynamicLayerStack. 18 // FuncTorchTLSBase doesn't have any metadata because it hasn't been defined 19 // yet. 20 // 21 // Here in pytorch/pytorch, we will pass around FuncTorchTLSBase*, but inside 22 // functorch, we will assign a FuncTorchTLSImpl* to the FunctorchTLSBase*. 23 // We can't directly pass around FunctorchTLSBase (without a pointer) because 24 // FuncTorchTLSImpl does not fit inside a FuncTorchTLSBase by virtue of having 25 // more elements. 26 struct TORCH_API FuncTorchTLSBase { 27 virtual ~FuncTorchTLSBase() = default; 28 virtual std::unique_ptr<FuncTorchTLSBase> deepcopy() const = 0; 29 30 virtual int64_t checkSupportsSingleLevelAutogradFunction() const = 0; 31 virtual void checkSupportsCppAutogradFunction() const = 0; 32 virtual void checkSupportsInplaceRequiresGrad() const = 0; 33 virtual void checkSupportsRetainGrad() const = 0; 34 }; 35 36 // returns deepcopy of the functorch tls 37 TORCH_API std::unique_ptr<FuncTorchTLSBase> getCopyOfFuncTorchTLS(); 38 39 // sets the functorch tls. always does a deep copy. 40 TORCH_API void setFuncTorchTLS( 41 const std::shared_ptr<const FuncTorchTLSBase>& state); 42 43 // get a mutable reference to the functorch tls 44 TORCH_API std::unique_ptr<FuncTorchTLSBase>& functorchTLSAccessor(); 45 46 } // namespace at::functorch 47