1 #pragma once 2 3 #include <ATen/ATen.h> 4 5 namespace at::caching { 6 7 // Some systems (just cudagraphs currently) will persist a static tensor output 8 // whose TensorImpl does not change across iterations. For these tensors caching 9 // dtype conversions is invalid. Additionally, there will be an extra reference 10 // count to these cached tensors that would prevent buffer inplacing and other 11 // checks on tensor uniqueness. If we are not using these systems the enabled 12 // flag will be false and we will avoid the hash lookup. 13 14 TORCH_API bool is_cached_tensor(const at::Tensor& t); 15 TORCH_API void add_cached_tensor(const at::Tensor& t); 16 TORCH_API void remove_cached_tensor(const at::Tensor& t); 17 TORCH_API void set_cached_tensors_enabled(bool enable); 18 19 // For gradient buffer stealing we will adjust the use count of tensors 20 // which are persisted by cudagraphs, just as we need to adjust reference 21 // count of tensors with hooks. 22 TORCH_API size_t adjusted_use_count(const at::Tensor& t); 23 24 } // namespace at::caching 25