// aten/src/ATen/CachedTensorUtils.cpp
#include <ATen/ATen.h>
#include <ATen/CachedTensorUtils.h>

#include <c10/util/flat_hash_map.h>

#include <mutex>

namespace at::caching {

using weakref_type = c10::weak_intrusive_ptr<TensorImpl, UndefinedTensorImpl>;

bool cached_tensorimpls_enabled = false;

// Like `cached_casts` in autocast_mode, we hash on the TensorImpl*
// and keep the pointer alive with a weakref value.
ska::flat_hash_map<TensorImpl*, weakref_type> cached_tensorimpls;
std::mutex cached_tensorimpl_mutex;
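
// Illustration (a hypothetical sketch, not part of the original file): a
// weak_intrusive_ptr bumps the TensorImpl's weak count, so the allocation
// behind the raw TensorImpl* key cannot be freed and reused while its entry
// remains in the map, even after all strong references are gone:
//
//   at::Tensor t = at::empty({2});
//   TensorImpl* key = t.unsafeGetTensorImpl();
//   cached_tensorimpls.emplace(key, weakref_type(t.getIntrusivePtr()));
//   // `key` stays a valid (non-dangling) map key until the entry is erased.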

// Returns true iff caching is enabled and `t`'s TensorImpl is registered
// in the cache.
bool is_cached_tensor(const at::Tensor& t) {
  if (!cached_tensorimpls_enabled) {
    return false;
  }
  const std::lock_guard<std::mutex> lock(cached_tensorimpl_mutex);
  return cached_tensorimpls.count(t.unsafeGetTensorImpl());
}

// Registers `t` in the cache, keyed on its TensorImpl*; the weakref value
// keeps the pointed-to allocation alive so the key cannot dangle.
void add_cached_tensor(const at::Tensor& t) {
  TORCH_INTERNAL_ASSERT(cached_tensorimpls_enabled);
  const std::lock_guard<std::mutex> lock(cached_tensorimpl_mutex);
  cached_tensorimpls.emplace(t.unsafeGetTensorImpl(), weakref_type(t.getIntrusivePtr()));
}

// Removes `t`'s entry from the cache, if present.
void remove_cached_tensor(const at::Tensor& t) {
  TORCH_INTERNAL_ASSERT(cached_tensorimpls_enabled);
  const std::lock_guard<std::mutex> lock(cached_tensorimpl_mutex);
  cached_tensorimpls.erase(t.unsafeGetTensorImpl());
}

// Globally enables or disables the cache. Note that the flag is read
// without holding `cached_tensorimpl_mutex`.
void set_cached_tensors_enabled(bool enabled) {
  cached_tensorimpls_enabled = enabled;
}

// Use count of `t`, discounting one reference when `t` is a cached tensor.
size_t adjusted_use_count(const at::Tensor& t) {
  return t.use_count() - (is_cached_tensor(t) ? 1 : 0);
}

} // namespace at::caching
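
// A minimal usage sketch (hypothetical; `example_cached_tensor_roundtrip` is
// a name introduced here purely for illustration). It exercises the API
// above end to end: enable caching, register a tensor, query it, and check
// the adjusted use count.
namespace {

[[maybe_unused]] void example_cached_tensor_roundtrip() {
  at::caching::set_cached_tensors_enabled(true);
  at::Tensor t = at::zeros({4});
  at::caching::add_cached_tensor(t);
  TORCH_INTERNAL_ASSERT(at::caching::is_cached_tensor(t));
  // `use_count()` includes our strong reference; `adjusted_use_count()`
  // subtracts one because `t` is cached.
  TORCH_INTERNAL_ASSERT(
      at::caching::adjusted_use_count(t) == t.use_count() - 1);
  at::caching::remove_cached_tensor(t);
  at::caching::set_cached_tensors_enabled(false);
}

} // anonymous namespace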