/**
 * Cache utils in this file is adapted from PyTorch/XLA
 * https://github.com/pytorch/xla/blob/master/third_party/xla_client/cache.h
 */

#pragma once

#include <cstddef>
#include <functional>
#include <list>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <utility>

namespace torch {
namespace lazy {

// Generic key and object cache with LRU expiration policy. The objects of type
// T will be stored as std::shared_ptr<T> and taken and returned as such, by the
// cache API.
//
// All public methods are thread-safe: every operation takes the internal
// mutex. A max_size of zero disables the cache entirely (Add is a no-op that
// returns its argument, lookups always miss).
template <
    typename K,
    typename T,
    typename H = std::hash<K>,
    typename E = std::equal_to<K>>
class Cache {
 public:
  using TypePtr = std::shared_ptr<T>;
  using Element = std::pair<K, TypePtr>;

  explicit Cache(size_t max_size) : max_size_(max_size) {}

  // Adds an object to the cache, unless it already exists. If the cache grows
  // beyond the limit set during construction, the oldest used object will be
  // removed from the cache.
  //
  // Returns the object stored under `key` after the call: the newly inserted
  // one, or — if the key was already present — the pre-existing object (the
  // argument is then discarded).
  TypePtr Add(K key, TypePtr object) {
    if (!max_size_) {
      return object;
    }
    std::lock_guard<std::mutex> slock(lock_);
    // Insert at the front (most-recently-used position) first; the map keys
    // are pointers into the list nodes, so the node must exist before the
    // map entry can be created.
    element_list_.emplace_front(Element(std::move(key), std::move(object)));
    auto it = element_list_.begin();
    auto emplace_result = element_map_.emplace(&it->first, it);
    if (!emplace_result.second) {
      // Key already present: drop the speculative node and refresh the
      // existing entry's LRU position instead.
      element_list_.erase(it);
      DoLRU(emplace_result.first->second);
    } else if (element_list_.size() > max_size_) {
      // Over capacity: evict the least-recently-used entry (list back).
      Element* last = &element_list_.back();
      element_map_.erase(&last->first);
      element_list_.pop_back();
    }
    return emplace_result.first->second->second;
  }

  // Retrieves the existing object if it exists. If it does, its position in
  // the LRU list gets moved to the head of the list.
  // Returns nullptr if no object with the specified key is found within the
  // cache.
  TypePtr Get(const K& key) {
    if (!max_size_) {
      return nullptr;
    }
    std::lock_guard<std::mutex> slock(lock_);
    auto it = element_map_.find(&key);
    if (it == element_map_.end()) {
      return nullptr;
    }
    DoLRU(it->second);
    return it->second->second;
  }

  // Returns the most-recently-used object without touching LRU order.
  // Fails (TORCH_CHECK) if the cache is empty.
  TypePtr GetLatest() {
    std::lock_guard<std::mutex> g(lock_);
    TORCH_CHECK(!element_list_.empty());
    return element_list_.front().second;
  }

  // Removes the object stored under `key`, if any.
  // Returns true if an entry was erased, false if the key was not present.
  bool Erase(const K& key) {
    if (!max_size_) {
      return false;
    }
    std::lock_guard<std::mutex> slock(lock_);
    auto it = element_map_.find(&key);
    if (it == element_map_.end()) {
      return false;
    }
    auto lit = it->second;
    element_map_.erase(it);
    element_list_.erase(lit);
    return true;
  }

  // Removes all entries from the cache.
  void Clear() {
    if (!max_size_) {
      return;
    }
    std::lock_guard<std::mutex> slock(lock_);
    element_map_.clear();
    element_list_.clear();
  }

  // Returns the number of cached entries (0 when the cache is disabled).
  int Numel() const {
    if (!max_size_) {
      return 0;
    }
    std::lock_guard<std::mutex> g(lock_);
    TORCH_CHECK(element_map_.size() == element_list_.size());
    // Explicit cast: size() is size_t, the public API returns int.
    return static_cast<int>(element_map_.size());
  }

 private:
  using ElementList = std::list<Element>;

  // Hash/equality adapters so the map can key on pointers to the K objects
  // stored inside the list nodes (avoids storing each key twice).
  struct Hasher {
    size_t operator()(const K* key) const {
      return hasher(*key);
    }

    H hasher;
  };

  struct Equaler {
    bool operator()(const K* k1, const K* k2) const {
      return equaler(*k1, *k2);
    }

    E equaler;
  };

  using ElementMap = std::
      unordered_map<const K*, typename ElementList::iterator, Hasher, Equaler>;

  // Moves the given entry to the front of the list (most-recently-used).
  // splice does not invalidate iterators, so map values stay valid.
  void DoLRU(typename ElementList::iterator it) {
    element_list_.splice(element_list_.begin(), element_list_, it);
  }

  mutable std::mutex lock_;
  const size_t max_size_ = 0;
  ElementList element_list_;
  ElementMap element_map_;
};

} // namespace lazy
} // namespace torch