xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/refcounting_hash_map.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_REFCOUNTING_HASH_MAP_H_
17 #define TENSORFLOW_COMPILER_XLA_REFCOUNTING_HASH_MAP_H_
18 
19 #include <functional>
20 #include <memory>
21 
22 #include "absl/base/thread_annotations.h"
23 #include "absl/container/node_hash_map.h"
24 #include "absl/synchronization/mutex.h"
25 #include "tensorflow/compiler/xla/statusor.h"
26 
27 namespace xla {
28 
29 // RefcountingHashMap is an "eager, thread-safe cache".
30 //
31 // Given a key k you can retrieve a shared_ptr to a value v.  If k is not
32 // already in the map, we construct a new V; if it is already in the map, we'll
33 // return the existing v.  Once all shared_ptrs are destroyed, the entry is
34 // removed from the map.
35 //
36 // This class is thread-safe.
37 //
38 // Word to the wise: You might want an erase() function here that removes a
39 // value from the map but leaves existing shared_ptrs intact.  My experience is,
40 // this is extremely complicated to implement correctly.
41 template <typename K, typename V>
42 class RefcountingHashMap {
43  public:
44   // Default-constructs new values.
45   RefcountingHashMap() = default;
46 
47   // Not copyable or movable because this contains internal pointers (namely,
48   // instances of Deleter contain pointers to `this` and into `map_`).
49   RefcountingHashMap(const RefcountingHashMap&) = delete;
50   RefcountingHashMap(RefcountingHashMap&&) = delete;
51   RefcountingHashMap& operator=(const RefcountingHashMap&) = delete;
52   RefcountingHashMap& operator=(RefcountingHashMap&&) = delete;
53 
54   // Gets the value for the given key.
55   //
56   // If the map doesn't contain a live value for the key, constructs one
57   // using `value_factory`.
GetOrCreateIfAbsent(const K & key,const std::function<std::unique_ptr<V> (const K &)> & value_factory)58   std::shared_ptr<V> GetOrCreateIfAbsent(
59       const K& key,
60       const std::function<std::unique_ptr<V>(const K&)>& value_factory) {
61     absl::MutexLock lock(&mu_);
62     auto it = map_.find(key);
63     if (it != map_.end()) {
64       // We ensure that the entry has not expired in case deleter was running
65       // when we have entered this block.
66       if (std::shared_ptr<V> value = it->second.lock()) {
67         return value;
68       }
69     }
70 
71     // Create entry in the map and then set its value, so the value can
72     // contain a pointer back into the map.
73     it = map_.emplace(key, std::weak_ptr<V>()).first;
74     std::shared_ptr<V> value(value_factory(key).release(),
75                              Deleter{it->first, *this});
76     it->second = value;  // Set the weak ptr to the shared ptr.
77     return value;
78   }
79 
80  private:
81   struct Deleter {
82     const K& key;  // Points into parent->map_.
83     RefcountingHashMap& parent;
84 
operatorDeleter85     void operator()(V* v) {
86       delete v;
87       absl::MutexLock lock(&parent.mu_);
88       // We must check if that the entry is still expired in case the value was
89       // replaced while the deleter was running.
90       auto it = parent.map_.find(key);
91       if (it != parent.map_.end() && it->second.expired()) {
92         parent.map_.erase(it);
93       }
94     }
95   };
96 
97   absl::Mutex mu_;
98   absl::node_hash_map<K, std::weak_ptr<V>> map_ ABSL_GUARDED_BY(mu_);
99 };
100 
101 }  // namespace xla
102 
103 #endif  // TENSORFLOW_COMPILER_XLA_REFCOUNTING_HASH_MAP_H_
104