1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/python/weakref_lru_cache.h"

#include <cstddef>
#include <cstdint>
#include <memory>
#include <thread>
#include <utility>
#include <vector>

#include "absl/cleanup/cleanup.h"
#include "absl/container/node_hash_map.h"
#include "absl/strings/str_cat.h"
#include "absl/synchronization/notification.h"
#include "pybind11/pybind11.h"
#include "tensorflow/compiler/xla/pjrt/lru_cache.h"
25
26 namespace jax {
27
28 class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
29 public:
30 struct Key {
31 pybind11::object context;
32 pybind11::args args;
33 pybind11::kwargs kwargs;
34
operator ==jax::WeakrefLRUCache::Key35 bool operator==(const Key& other) const {
36 if (!context.equal(other.context)) return false;
37 if (!args.equal(other.args)) return false;
38 if (!kwargs.equal(other.kwargs)) return false;
39 return true;
40 }
41
42 template <typename H>
AbslHashValue(H h,const Key & key)43 friend H AbslHashValue(H h, const Key& key) {
44 h = H::combine(std::move(h), pybind11::hash(key.context));
45 h = H::combine(std::move(h), pybind11::hash(key.args));
46 h = H::combine(std::move(h), key.kwargs.size());
47 for (auto& kv : key.kwargs) {
48 h = H::combine(std::move(h), pybind11::hash(kv.first));
49 h = H::combine(std::move(h), pybind11::hash(kv.second));
50 }
51 return h;
52 }
53 };
54
55 struct CacheEntry {
56 bool has_result = false;
57 pybind11::object result;
58 absl::Notification completed;
59 };
60
61 struct CacheInfo {
62 int64_t hits;
63 int64_t misses;
64 int64_t maxsize;
65 int64_t currsize;
66 };
67
68 struct UnboundWeakrefCacheEntry {
69 pybind11::handle object;
70 WeakrefLRUCache* cache;
71 size_t cached_hash;
72 };
73
74 struct WeakrefCacheEntry {
75 pybind11::weakref weakref;
76 size_t cached_hash;
77 };
78
79 struct WeakrefKeyHash {
80 using is_transparent = void;
81
operator ()jax::WeakrefLRUCache::WeakrefKeyHash82 size_t operator()(const UnboundWeakrefCacheEntry& v) const {
83 return v.cached_hash;
84 }
operator ()jax::WeakrefLRUCache::WeakrefKeyHash85 size_t operator()(const WeakrefCacheEntry& v) const {
86 return v.cached_hash;
87 }
88 };
89
90 struct WeakrefKeyEq {
91 using is_transparent = void;
operator ()jax::WeakrefLRUCache::WeakrefKeyEq92 bool operator()(const WeakrefCacheEntry& lhs,
93 const WeakrefCacheEntry& rhs) const {
94 return lhs.weakref.equal(rhs.weakref);
95 }
operator ()jax::WeakrefLRUCache::WeakrefKeyEq96 bool operator()(const WeakrefCacheEntry& lhs,
97 const UnboundWeakrefCacheEntry& rhs) const {
98 PyObject* obj = PyWeakref_GET_OBJECT(lhs.weakref.ptr());
99 if (obj == Py_None) {
100 return false;
101 }
102 return pybind11::reinterpret_borrow<pybind11::object>(obj).equal(
103 rhs.object);
104 }
105 };
106
107 using Cache = xla::LRUCache<Key, std::shared_ptr<CacheEntry>>;
WeakrefLRUCache(pybind11::function cache_context_fn,pybind11::function fn,int64_t maxsize)108 WeakrefLRUCache(pybind11::function cache_context_fn, pybind11::function fn,
109 int64_t maxsize)
110 : cache_context_fn_(cache_context_fn), fn_(fn), lru_list_(maxsize) {}
111
GetCache(const UnboundWeakrefCacheEntry & key)112 std::shared_ptr<Cache> GetCache(const UnboundWeakrefCacheEntry& key) {
113 auto it = entries_.find(key);
114 if (it != entries_.end()) {
115 return (it->second);
116 }
117 pybind11::weakref weakref(
118 key.object, pybind11::cpp_function([this_weak = weak_from_this(),
119 cached_hash = key.cached_hash](
120 pybind11::handle weakref) {
121 auto cache = this_weak.lock();
122 if (cache == nullptr) {
123 return;
124 }
125 cache->entries_.erase(WeakrefCacheEntry{
126 pybind11::reinterpret_borrow<pybind11::weakref>(weakref),
127 cached_hash});
128 }));
129 return (entries_
130 .emplace(WeakrefCacheEntry{std::move(weakref), key.cached_hash},
131 std::make_shared<Cache>(&lru_list_))
132 .first->second);
133 }
134
Call(pybind11::object weakref_key,pybind11::args args,pybind11::kwargs kwargs)135 pybind11::object Call(pybind11::object weakref_key, pybind11::args args,
136 pybind11::kwargs kwargs) {
137 pybind11::object context = cache_context_fn_();
138 std::shared_ptr<Cache> cache_ptr = GetCache(UnboundWeakrefCacheEntry{
139 weakref_key, this, static_cast<size_t>(pybind11::hash(weakref_key))});
140 Cache& cache = *cache_ptr;
141 ++total_queries_;
142
143 bool inserted = false;
144 Key key{context, args, kwargs};
145 auto entry = cache.GetOrCreateIfAbsent(key, [&inserted](const Key& key) {
146 inserted = true;
147 return std::make_shared<CacheEntry>();
148 });
149 if (!entry->completed.HasBeenNotified()) {
150 if (inserted) {
151 ++misses_;
152 absl::Cleanup notify = [&] { entry->completed.Notify(); };
153 entry->result = fn_(weakref_key, *args, **kwargs);
154 entry->has_result = true;
155 } else {
156 pybind11::gil_scoped_release release;
157 entry->completed.WaitForNotification();
158 }
159 }
160
161 if (entry->has_result) {
162 return entry->result;
163 } else {
164 ++misses_;
165 return fn_(weakref_key, *args, **kwargs);
166 }
167 }
GetCacheInfo() const168 CacheInfo GetCacheInfo() const {
169 CacheInfo result;
170 result.hits = total_queries_ - misses_;
171 result.misses = misses_;
172 result.maxsize = lru_list_.Capacity();
173 result.currsize = lru_list_.Size();
174 return result;
175 }
Clear()176 void Clear() {
177 total_queries_ = misses_ = 0;
178 entries_.clear();
179 }
180
181 pybind11::function cache_context_fn_;
182 pybind11::function fn_;
183 Cache::LRUList lru_list_;
184 absl::node_hash_map<WeakrefCacheEntry, std::shared_ptr<Cache>, WeakrefKeyHash,
185 WeakrefKeyEq>
186 entries_;
187 int64_t misses_ = 0;
188 int64_t total_queries_ = 0;
189 };
190
191 namespace {
192 namespace py = ::pybind11;
193 } // namespace
194
BuildWeakrefLRUCacheAPI(pybind11::module & m)195 void BuildWeakrefLRUCacheAPI(pybind11::module& m) {
196 auto weakref_lru_cache =
197 py::class_<WeakrefLRUCache, std::shared_ptr<WeakrefLRUCache>>(
198 m, "WeakrefLRUCache")
199 .def("__call__", &WeakrefLRUCache::Call)
200 .def("cache_info", &WeakrefLRUCache::GetCacheInfo)
201 .def("cache_clear", &WeakrefLRUCache::Clear);
202 py::class_<WeakrefLRUCache::CacheInfo>(weakref_lru_cache,
203 "WeakrefLRUCacheInfo")
204 .def_readonly("hits", &WeakrefLRUCache::CacheInfo::hits)
205 .def_readonly("misses", &WeakrefLRUCache::CacheInfo::misses)
206 .def_readonly("maxsize", &WeakrefLRUCache::CacheInfo::maxsize)
207 .def_readonly("currsize", &WeakrefLRUCache::CacheInfo::currsize)
208 .def("__repr__", [](WeakrefLRUCache::CacheInfo& info) {
209 return absl::StrCat(
210 "WeakrefLRUCache(hits=", info.hits, ", misses=", info.misses,
211 ", maxsize=", info.maxsize, ", currsize=", info.currsize, ")");
212 });
213 m.def(
214 "weakref_lru_cache",
215 [](pybind11::function cache_context_fn, pybind11::function fn,
216 int64_t maxsize) {
217 return std::make_shared<WeakrefLRUCache>(cache_context_fn, fn, maxsize);
218 },
219 pybind11::arg("cache_context_fn"), pybind11::arg("fn"),
220 pybind11::arg("maxsize") = 2048);
221 }
222
223 } // namespace jax
224