1 #include <ATen/ATen.h> 2 #include <ATen/CachedTensorUtils.h> 3 4 #include <c10/util/flat_hash_map.h> 5 6 namespace at::caching { 7 8 9 using weakref_type = c10::weak_intrusive_ptr<TensorImpl, UndefinedTensorImpl>; 10 11 bool cached_tensorimpls_enabled = false; 12 13 // Like `cached_casts` in autocast_mode, we hash on the TensorImpl* 14 // and keep the pointer alive with a weakref value. 15 ska::flat_hash_map<TensorImpl*, weakref_type> cached_tensorimpls; 16 std::mutex cached_tensorimpl_mutex; 17 18 is_cached_tensor(const at::Tensor & t)19bool is_cached_tensor(const at::Tensor& t) { 20 if (!cached_tensorimpls_enabled) { 21 return false; 22 } 23 const std::lock_guard<std::mutex> lock(cached_tensorimpl_mutex); 24 return cached_tensorimpls.count(t.unsafeGetTensorImpl()); 25 } 26 add_cached_tensor(const at::Tensor & t)27void add_cached_tensor(const at::Tensor& t) { 28 TORCH_INTERNAL_ASSERT(cached_tensorimpls_enabled); 29 const std::lock_guard<std::mutex> lock(cached_tensorimpl_mutex); 30 cached_tensorimpls.emplace(t.unsafeGetTensorImpl(), weakref_type(t.getIntrusivePtr())); 31 } 32 remove_cached_tensor(const at::Tensor & t)33void remove_cached_tensor(const at::Tensor& t) { 34 TORCH_INTERNAL_ASSERT(cached_tensorimpls_enabled); 35 const std::lock_guard<std::mutex> lock(cached_tensorimpl_mutex); 36 cached_tensorimpls.erase(t.unsafeGetTensorImpl()); 37 } 38 set_cached_tensors_enabled(bool enabled)39void set_cached_tensors_enabled(bool enabled) { 40 cached_tensorimpls_enabled = enabled; 41 } 42 adjusted_use_count(const at::Tensor & t)43size_t adjusted_use_count(const at::Tensor& t) { 44 return t.use_count() - (is_cached_tensor(t) ? 1 : 0); 45 } 46 47 } // namespace at::caching 48