#include <ATen/ATen.h>

#include <ATen/CachedTensorUtils.h>
#include <c10/util/flat_hash_map.h>

namespace at::caching {

using weakref_type = c10::weak_intrusive_ptr<TensorImpl, UndefinedTensorImpl>;

bool cached_tensorimpls_enabled = false;

// Like `cached_casts` in autocast_mode, we hash on the TensorImpl*
// and keep the pointer alive with a weakref value.
ska::flat_hash_map<TensorImpl*, weakref_type> cached_tensorimpls;
std::mutex cached_tensorimpl_mutex;

bool is_cached_tensor(const at::Tensor& t) {
  if (!cached_tensorimpls_enabled) {
    return false;
  }
  const std::lock_guard<std::mutex> lock(cached_tensorimpl_mutex);
  return cached_tensorimpls.count(t.unsafeGetTensorImpl());
}

void add_cached_tensor(const at::Tensor& t) {
  TORCH_INTERNAL_ASSERT(cached_tensorimpls_enabled);
  const std::lock_guard<std::mutex> lock(cached_tensorimpl_mutex);
  cached_tensorimpls.emplace(
      t.unsafeGetTensorImpl(), weakref_type(t.getIntrusivePtr()));
}

void remove_cached_tensor(const at::Tensor& t) {
  TORCH_INTERNAL_ASSERT(cached_tensorimpls_enabled);
  const std::lock_guard<std::mutex> lock(cached_tensorimpl_mutex);
  cached_tensorimpls.erase(t.unsafeGetTensorImpl());
}

void set_cached_tensors_enabled(bool enabled) {
  cached_tensorimpls_enabled = enabled;
}

// Use count of `t`, minus the single strong reference assumed to be held by
// the caching mechanism when `t` has been registered via add_cached_tensor().
size_t adjusted_use_count(const at::Tensor& t) {
  return t.use_count() - (is_cached_tensor(t) ? 1 : 0);
}

} // namespace at::caching
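
// -----------------------------------------------------------------------------
// Usage sketch (illustrative; the cache container and values are hypothetical).
// It shows the intended calling pattern, assuming the caller (e.g. an external
// tensor cache) keeps exactly one strong reference to each tensor it registers,
// which is the reference that adjusted_use_count() subtracts out.
//
//   at::caching::set_cached_tensors_enabled(true);
//
//   at::Tensor t = at::ones({2, 2});
//   std::vector<at::Tensor> my_cache = {t};   // hypothetical cache holding a
//                                             // strong reference to t
//   at::caching::add_cached_tensor(t);        // register t's TensorImpl
//
//   TORCH_INTERNAL_ASSERT(at::caching::is_cached_tensor(t));
//   // t.use_count() == 2 (t itself + my_cache), but the adjusted count
//   // ignores the cache's reference:
//   TORCH_INTERNAL_ASSERT(at::caching::adjusted_use_count(t) == 1);
//
//   at::caching::remove_cached_tensor(t);     // unregister when evicting
//   at::caching::set_cached_tensors_enabled(false);
// -----------------------------------------------------------------------------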