#pragma once

#include <ATen/ATen.h>

namespace at::caching {

// Some systems (currently just cudagraphs) persist a static tensor output
// whose TensorImpl does not change across iterations. For these tensors,
// caching dtype conversions is invalid. Additionally, these cached tensors
// carry an extra reference count that would otherwise defeat buffer
// inplacing and other tensor-uniqueness checks. If none of these systems are
// in use, the enabled flag is false and the hash lookup is skipped.

TORCH_API bool is_cached_tensor(const at::Tensor& t);
TORCH_API void add_cached_tensor(const at::Tensor& t);
TORCH_API void remove_cached_tensor(const at::Tensor& t);
TORCH_API void set_cached_tensors_enabled(bool enable);

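// A minimal usage sketch (illustrative only: `output` names a hypothetical
// static tensor persisted across cudagraph replays):
//
//   at::caching::set_cached_tensors_enabled(true);
//   at::caching::add_cached_tensor(output);    // register the static output
//   TORCH_INTERNAL_ASSERT(at::caching::is_cached_tensor(output));
//   // ... graph replays reuse `output` without caching dtype conversions ...
//   at::caching::remove_cached_tensor(output); // when the graph is freed
//   at::caching::set_cached_tensors_enabled(false);
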
// For gradient buffer stealing, we adjust the use count of tensors
// persisted by cudagraphs, just as we adjust the reference count of tensors
// with hooks.
TORCH_API size_t adjusted_use_count(const at::Tensor& t);
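
// A hedged sketch of a uniqueness check built on adjusted_use_count
// (`grad` is a hypothetical gradient tensor; the check assumes cached
// tensors hold exactly one extra reference, which this call discounts):
//
//   if (at::caching::adjusted_use_count(grad) == 1) {
//     // `grad` is uniquely owned, so its buffer may be stolen in place
//     // of allocating a fresh gradient buffer.
//   }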

} // namespace at::caching