xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/python/weakref_lru_cache.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/python/weakref_lru_cache.h"
17 
18 #include <memory>
19 #include <utility>
20 
21 #include "absl/cleanup/cleanup.h"
22 #include "absl/synchronization/notification.h"
23 #include "pybind11/pybind11.h"
24 #include "tensorflow/compiler/xla/pjrt/lru_cache.h"
25 
26 namespace jax {
27 
28 class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
29  public:
30   struct Key {
31     pybind11::object context;
32     pybind11::args args;
33     pybind11::kwargs kwargs;
34 
operator ==jax::WeakrefLRUCache::Key35     bool operator==(const Key& other) const {
36       if (!context.equal(other.context)) return false;
37       if (!args.equal(other.args)) return false;
38       if (!kwargs.equal(other.kwargs)) return false;
39       return true;
40     }
41 
42     template <typename H>
AbslHashValue(H h,const Key & key)43     friend H AbslHashValue(H h, const Key& key) {
44       h = H::combine(std::move(h), pybind11::hash(key.context));
45       h = H::combine(std::move(h), pybind11::hash(key.args));
46       h = H::combine(std::move(h), key.kwargs.size());
47       for (auto& kv : key.kwargs) {
48         h = H::combine(std::move(h), pybind11::hash(kv.first));
49         h = H::combine(std::move(h), pybind11::hash(kv.second));
50       }
51       return h;
52     }
53   };
54 
55   struct CacheEntry {
56     bool has_result = false;
57     pybind11::object result;
58     absl::Notification completed;
59   };
60 
61   struct CacheInfo {
62     int64_t hits;
63     int64_t misses;
64     int64_t maxsize;
65     int64_t currsize;
66   };
67 
68   struct UnboundWeakrefCacheEntry {
69     pybind11::handle object;
70     WeakrefLRUCache* cache;
71     size_t cached_hash;
72   };
73 
74   struct WeakrefCacheEntry {
75     pybind11::weakref weakref;
76     size_t cached_hash;
77   };
78 
79   struct WeakrefKeyHash {
80     using is_transparent = void;
81 
operator ()jax::WeakrefLRUCache::WeakrefKeyHash82     size_t operator()(const UnboundWeakrefCacheEntry& v) const {
83       return v.cached_hash;
84     }
operator ()jax::WeakrefLRUCache::WeakrefKeyHash85     size_t operator()(const WeakrefCacheEntry& v) const {
86       return v.cached_hash;
87     }
88   };
89 
90   struct WeakrefKeyEq {
91     using is_transparent = void;
operator ()jax::WeakrefLRUCache::WeakrefKeyEq92     bool operator()(const WeakrefCacheEntry& lhs,
93                     const WeakrefCacheEntry& rhs) const {
94       return lhs.weakref.equal(rhs.weakref);
95     }
operator ()jax::WeakrefLRUCache::WeakrefKeyEq96     bool operator()(const WeakrefCacheEntry& lhs,
97                     const UnboundWeakrefCacheEntry& rhs) const {
98       PyObject* obj = PyWeakref_GET_OBJECT(lhs.weakref.ptr());
99       if (obj == Py_None) {
100         return false;
101       }
102       return pybind11::reinterpret_borrow<pybind11::object>(obj).equal(
103           rhs.object);
104     }
105   };
106 
107   using Cache = xla::LRUCache<Key, std::shared_ptr<CacheEntry>>;
WeakrefLRUCache(pybind11::function cache_context_fn,pybind11::function fn,int64_t maxsize)108   WeakrefLRUCache(pybind11::function cache_context_fn, pybind11::function fn,
109                   int64_t maxsize)
110       : cache_context_fn_(cache_context_fn), fn_(fn), lru_list_(maxsize) {}
111 
GetCache(const UnboundWeakrefCacheEntry & key)112   std::shared_ptr<Cache> GetCache(const UnboundWeakrefCacheEntry& key) {
113     auto it = entries_.find(key);
114     if (it != entries_.end()) {
115       return (it->second);
116     }
117     pybind11::weakref weakref(
118         key.object, pybind11::cpp_function([this_weak = weak_from_this(),
119                                             cached_hash = key.cached_hash](
120                                                pybind11::handle weakref) {
121           auto cache = this_weak.lock();
122           if (cache == nullptr) {
123             return;
124           }
125           cache->entries_.erase(WeakrefCacheEntry{
126               pybind11::reinterpret_borrow<pybind11::weakref>(weakref),
127               cached_hash});
128         }));
129     return (entries_
130                 .emplace(WeakrefCacheEntry{std::move(weakref), key.cached_hash},
131                          std::make_shared<Cache>(&lru_list_))
132                 .first->second);
133   }
134 
Call(pybind11::object weakref_key,pybind11::args args,pybind11::kwargs kwargs)135   pybind11::object Call(pybind11::object weakref_key, pybind11::args args,
136                         pybind11::kwargs kwargs) {
137     pybind11::object context = cache_context_fn_();
138     std::shared_ptr<Cache> cache_ptr = GetCache(UnboundWeakrefCacheEntry{
139         weakref_key, this, static_cast<size_t>(pybind11::hash(weakref_key))});
140     Cache& cache = *cache_ptr;
141     ++total_queries_;
142 
143     bool inserted = false;
144     Key key{context, args, kwargs};
145     auto entry = cache.GetOrCreateIfAbsent(key, [&inserted](const Key& key) {
146       inserted = true;
147       return std::make_shared<CacheEntry>();
148     });
149     if (!entry->completed.HasBeenNotified()) {
150       if (inserted) {
151         ++misses_;
152         absl::Cleanup notify = [&] { entry->completed.Notify(); };
153         entry->result = fn_(weakref_key, *args, **kwargs);
154         entry->has_result = true;
155       } else {
156         pybind11::gil_scoped_release release;
157         entry->completed.WaitForNotification();
158       }
159     }
160 
161     if (entry->has_result) {
162       return entry->result;
163     } else {
164       ++misses_;
165       return fn_(weakref_key, *args, **kwargs);
166     }
167   }
GetCacheInfo() const168   CacheInfo GetCacheInfo() const {
169     CacheInfo result;
170     result.hits = total_queries_ - misses_;
171     result.misses = misses_;
172     result.maxsize = lru_list_.Capacity();
173     result.currsize = lru_list_.Size();
174     return result;
175   }
Clear()176   void Clear() {
177     total_queries_ = misses_ = 0;
178     entries_.clear();
179   }
180 
181   pybind11::function cache_context_fn_;
182   pybind11::function fn_;
183   Cache::LRUList lru_list_;
184   absl::node_hash_map<WeakrefCacheEntry, std::shared_ptr<Cache>, WeakrefKeyHash,
185                       WeakrefKeyEq>
186       entries_;
187   int64_t misses_ = 0;
188   int64_t total_queries_ = 0;
189 };
190 
191 namespace {
192 namespace py = ::pybind11;
193 }  // namespace
194 
BuildWeakrefLRUCacheAPI(pybind11::module & m)195 void BuildWeakrefLRUCacheAPI(pybind11::module& m) {
196   auto weakref_lru_cache =
197       py::class_<WeakrefLRUCache, std::shared_ptr<WeakrefLRUCache>>(
198           m, "WeakrefLRUCache")
199           .def("__call__", &WeakrefLRUCache::Call)
200           .def("cache_info", &WeakrefLRUCache::GetCacheInfo)
201           .def("cache_clear", &WeakrefLRUCache::Clear);
202   py::class_<WeakrefLRUCache::CacheInfo>(weakref_lru_cache,
203                                          "WeakrefLRUCacheInfo")
204       .def_readonly("hits", &WeakrefLRUCache::CacheInfo::hits)
205       .def_readonly("misses", &WeakrefLRUCache::CacheInfo::misses)
206       .def_readonly("maxsize", &WeakrefLRUCache::CacheInfo::maxsize)
207       .def_readonly("currsize", &WeakrefLRUCache::CacheInfo::currsize)
208       .def("__repr__", [](WeakrefLRUCache::CacheInfo& info) {
209         return absl::StrCat(
210             "WeakrefLRUCache(hits=", info.hits, ", misses=", info.misses,
211             ", maxsize=", info.maxsize, ", currsize=", info.currsize, ")");
212       });
213   m.def(
214       "weakref_lru_cache",
215       [](pybind11::function cache_context_fn, pybind11::function fn,
216          int64_t maxsize) {
217         return std::make_shared<WeakrefLRUCache>(cache_context_fn, fn, maxsize);
218       },
219       pybind11::arg("cache_context_fn"), pybind11::arg("fn"),
220       pybind11::arg("maxsize") = 2048);
221 }
222 
223 }  // namespace jax
224