xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/core/cache.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
/**
 * Cache utilities in this file are adapted from PyTorch/XLA:
 * https://github.com/pytorch/xla/blob/master/third_party/xla_client/cache.h
 */
5 
6 #pragma once
7 
8 #include <functional>
9 #include <list>
10 #include <memory>
11 #include <mutex>
12 #include <unordered_map>
13 #include <utility>
14 
15 namespace torch {
16 namespace lazy {
17 
// Generic key -> object cache with a least-recently-used (LRU) eviction
// policy. Objects of type T are stored as std::shared_ptr<T> and are passed to
// and returned from the cache API as such, so an object handed out by the
// cache stays alive for the caller even after the cache evicts its entry.
//
// A max_size of 0 disables the cache: Add() hands back its argument without
// storing it, Get() returns nullptr, Erase() returns false, Clear() does
// nothing, Numel() returns 0, and GetLatest() always fails its TORCH_CHECK
// (nothing is ever stored).
//
// Every method that touches the cache contents holds an internal mutex for its
// duration, so concurrent calls on one instance are serialized.
//
// Template parameters:
//   K - key type.
//   T - cached object type (held through std::shared_ptr<T>).
//   H - hash functor for K.
//   E - equality functor for K.
template <
    typename K,
    typename T,
    typename H = std::hash<K>,
    typename E = std::equal_to<K>>
class Cache {
 public:
  using TypePtr = std::shared_ptr<T>;
  using Element = std::pair<K, TypePtr>;

  // Creates a cache that holds at most `max_size` entries (0 disables it).
  explicit Cache(size_t max_size) : max_size_(max_size) {}

  // Adds `object` under `key` unless `key` is already cached.
  //
  // Returns the object associated with `key` after the call:
  // - If the key was already present, returns the previously cached object.
  //   The incoming `object` is discarded.
  // - Otherwise returns `object` itself.
  // In both cases the entry becomes the most recently used one. If the insert
  // makes the cache hold more than max_size entries, the least recently used
  // entry is evicted.
  TypePtr Add(K key, TypePtr object) {
    if (!max_size_) {
      return object;
    }
    std::lock_guard<std::mutex> slock(lock_);

    // Hit: refresh the existing entry. Looking the key up first avoids
    // building a list node that would immediately be thrown away.
    auto found = element_map_.find(&key);
    if (found != element_map_.end()) {
      DoLRU(found->second);
      return found->second->second;
    }

    element_list_.emplace_front(std::move(key), std::move(object));
    auto it = element_list_.begin();
    // The map is keyed by a pointer to the key stored inside the list node.
    // std::list nodes never relocate (splice included), so that pointer stays
    // valid until the node itself is erased.
    try {
      element_map_.emplace(&it->first, it);
    } catch (...) {
      // Keep the list and the index in sync if hashing or allocation throws.
      element_list_.pop_front();
      throw;
    }

    if (element_list_.size() > max_size_) {
      // The back of the list is the least recently used entry. It cannot be
      // the node just inserted: size > max_size_ >= 1 implies size >= 2.
      Element& oldest = element_list_.back();
      element_map_.erase(&oldest.first);
      element_list_.pop_back();
    }
    return it->second;
  }

  // Returns the object cached under `key` and marks it as most recently used.
  // Returns nullptr when the key is not cached (or the cache is disabled).
  TypePtr Get(const K& key) {
    if (!max_size_) {
      return nullptr;
    }
    std::lock_guard<std::mutex> slock(lock_);
    auto it = element_map_.find(&key);
    if (it == element_map_.end()) {
      return nullptr;
    }
    DoLRU(it->second);
    return it->second->second;
  }

  // Returns the most recently added or accessed object without changing the
  // LRU order. Fails a TORCH_CHECK when the cache is empty.
  TypePtr GetLatest() {
    std::lock_guard<std::mutex> g(lock_);
    TORCH_CHECK(!element_list_.empty());
    return element_list_.front().second;
  }

  // Removes the entry for `key`. Returns true if an entry was removed.
  bool Erase(const K& key) {
    if (!max_size_) {
      return false;
    }
    std::lock_guard<std::mutex> slock(lock_);
    auto it = element_map_.find(&key);
    if (it == element_map_.end()) {
      return false;
    }
    auto lit = it->second;
    // Drop the index entry before the node its key pointer refers to.
    element_map_.erase(it);
    element_list_.erase(lit);
    return true;
  }

  // Removes every entry from the cache.
  void Clear() {
    if (!max_size_) {
      return;
    }
    std::lock_guard<std::mutex> slock(lock_);
    element_map_.clear();
    element_list_.clear();
  }

  // Returns the number of cached entries (0 when the cache is disabled).
  // Also checks that the LRU list and the lookup index agree in size.
  int Numel() const {
    if (!max_size_) {
      return 0;
    }
    std::lock_guard<std::mutex> g(lock_);
    TORCH_CHECK(element_map_.size() == element_list_.size());
    return static_cast<int>(element_map_.size());
  }

 private:
  // Entries ordered from most recently used (front) to least recently used
  // (back): inserts go to the front, DoLRU splices to the front, and eviction
  // pops from the back.
  using ElementList = std::list<Element>;

  // Hashes a key through a pointer so the index hashes key values, not
  // addresses.
  struct Hasher {
    size_t operator()(const K* key) const {
      return hasher(*key);
    }

    H hasher;
  };

  // Compares keys through pointers so the index compares key values, not
  // addresses.
  struct Equaler {
    bool operator()(const K* k1, const K* k2) const {
      return equaler(*k1, *k2);
    }

    E equaler;
  };

  // Lookup index: pointer to the key stored in a list node -> that node.
  using ElementMap = std::
      unordered_map<const K*, typename ElementList::iterator, Hasher, Equaler>;

  // Moves the entry at `it` to the front (most recently used position).
  // splice relinks the node in place, so iterators and key pointers held by
  // element_map_ remain valid.
  void DoLRU(typename ElementList::iterator it) {
    element_list_.splice(element_list_.begin(), element_list_, it);
  }

  mutable std::mutex lock_;
  const size_t max_size_ = 0;
  ElementList element_list_;
  ElementMap element_map_;
};
142 
143 } // namespace lazy
144 } // namespace torch
145