xref: /aosp_15_r20/external/tensorflow/tensorflow/core/data/service/cross_trainer_cache.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_DATA_SERVICE_CROSS_TRAINER_CACHE_H_
16 #define TENSORFLOW_CORE_DATA_SERVICE_CROSS_TRAINER_CACHE_H_
17 
18 #include <cstddef>
19 #include <deque>
20 #include <functional>
21 #include <limits>
22 #include <memory>
23 #include <string>
24 #include <utility>
25 
26 #include "absl/container/flat_hash_map.h"
27 #include "tensorflow/core/data/service/logging_utils.h"
28 #include "tensorflow/core/framework/metrics.h"
29 #include "tensorflow/core/platform/errors.h"
30 #include "tensorflow/core/platform/mutex.h"
31 #include "tensorflow/core/platform/status.h"
32 #include "tensorflow/core/platform/statusor.h"
33 #include "tensorflow/core/platform/thread_annotations.h"
34 
35 namespace tensorflow {
36 namespace data {
37 
// Sliding-window cache shared across concurrent trainers. Readers call `Get` to
// read elements they haven't read. After a trainer reads an element, it remains
// in the cache and the data is shared with other trainers. This is useful for
// datasets involving expensive computation, where multiple models use the same
// data for training — for example, during hyperparameter tuning.
//
// The cache progresses when a trainer that has consumed all elements in the
// cache requests additional data. It has a bounded size. Elements are garbage
// collected when the cache becomes full. Consequently, trainers read from a
// sliding window through the dataset and may not read the full dataset.
//
// The `CrossTrainerCache` class is thread-safe.
//
// Example usage:
//
//   // `InfiniteRange` returns 1, 2, 3, ... in the `GetNext` calls.
//   class InfiniteRange : public CachableSequence<int64_t> {
//    public:
//     StatusOr<int64_t> GetNext() override {
//       return next_++;
//     }
//
//     size_t GetElementSizeBytes(const int64_t& element) const override {
//       return sizeof(element);
//     }
//
//    private:
//     int64_t next_ = 1;
//   };
//
//   CrossTrainerCache<int64_t> cache(
//       /*max_cache_size_bytes=*/10 * (size_t{1} << 30),  // 10GB
//       std::make_unique<InfiniteRange>());
//
//   std::shared_ptr<int64_t> next;
//   TF_ASSIGN_OR_RETURN(next, cache.Get("Trainer 1"));  // Returns 1
//   TF_ASSIGN_OR_RETURN(next, cache.Get("Trainer 2"));  // Returns 1
//   TF_ASSIGN_OR_RETURN(next, cache.Get("Trainer 1"));  // Returns 2
//   TF_ASSIGN_OR_RETURN(next, cache.Get("Trainer 2"));  // Returns 2
//
78 // To use the cache, the user needs to define a `CachableSequence` to generate
79 // an infinite sequence of data. It should implement a `GetNext` method to
80 // produce elements, and a `GetElementSizeBytes` method to estimate the element
81 // size in bytes.
82 template <class ElementType>
83 class CachableSequence {
84  public:
85   virtual ~CachableSequence() = default;
86 
87   // Returns the next element to be cached.
88   virtual StatusOr<ElementType> GetNext() = 0;
89 
90   // Returns the estimated size of the element in bytes.
91   virtual size_t GetElementSizeBytes(const ElementType&) const = 0;
92 };
93 
94 // Sliding-window cache shared across concurrent trainers.
95 template <class ElementType>
96 class CrossTrainerCache {
97  public:
98   // Creates a `CrossTrainerCache` with `max_cache_size_bytes` of memory budget.
99   // The cache should be able to hold at least one element, i.e.:
100   // REQUIRES: `max_cache_size_bytes >= max(GetElementSizeBytes(*))`
101   explicit CrossTrainerCache(
102       size_t max_cache_size_bytes,
103       std::unique_ptr<CachableSequence<ElementType>> cachable_sequence);
104   virtual ~CrossTrainerCache() = default;
105   CrossTrainerCache(const CrossTrainerCache&) = delete;
106   CrossTrainerCache& operator=(const CrossTrainerCache&) = delete;
107 
108   // Gets the next element for a trainer. A `trainer_id` identifies the trainer
109   // reading from the cache. A trainer reads the next element it hasn't read
110   // before. After a trainer reads data, the data is cached and reused by other
111   // trainers.
112   StatusOr<std::shared_ptr<const ElementType>> Get(
113       const std::string& trainer_id);
114 
115   // Cancels the cache with `status` and notifies the readers. After cancelling,
116   // all `Get` calls will return `status`.
117   // REQUIRES: !status.ok()
118   void Cancel(Status status);
119 
120   // Returns true if the cache has been cancelled.
121   bool IsCancelled() const;
122 
123  private:
124   struct CacheQueryResult {
125     std::shared_ptr<const ElementType> element;
126     bool cache_hit;
127   };
128 
129   // Returns the next element and metrics about this query.
130   StatusOr<CacheQueryResult> GetCacheQueryResult(const std::string& trainer_id);
131 
132   // Returns true if element is ready for `trainer_id`. An element is ready if
133   // other trainers have read the data and the data remains in the cache. If the
134   // data is not ready, one of the trainers need to extend the cache.
135   bool IsElementReady(const std::string& trainer_id);
136 
137   // Returns the absolute element index relative to the dataset (not relative to
138   // the cached elements).
139   size_t GetElementIndex(const std::string& trainer_id);
140 
141   // Returns the next element for `trainer_id`.
142   StatusOr<std::shared_ptr<const ElementType>> GetElement(
143       const std::string& trainer_id);
144 
145   // Reads a new element and writes it into the cache.
146   Status ExtendCache();
147 
148   // Frees old elements to keep the cache size below `max_cache_size_bytes_`.
149   // `new_element_size_bytes` is the size of the new element being inserted.
150   void FreeSpace(size_t new_element_size_bytes);
151 
152   // Records the cache hit rate and cache size.
153   void RecordMetrics(const CacheQueryResult& result);
154 
155   // Maximum cache size in bytes.
156   const size_t max_cache_size_bytes_;
157 
158   // The element sequence over which the sliding window cache operates.
159   std::unique_ptr<CachableSequence<ElementType>> cachable_sequence_;
160 
161   mutable mutex mu_;
162   mutable condition_variable cv_;
163 
164   // If `status_` is non-OK, the cache is cancelled, and all method calls will
165   // return this status.
166   Status status_ TF_GUARDED_BY(mu_) = OkStatus();
167 
168   // `cache_` stores the cached elements.
169   std::deque<std::shared_ptr<const ElementType>> cache_ TF_GUARDED_BY(mu_);
170   size_t cache_size_bytes_ TF_GUARDED_BY(mu_) = 0;
171   size_t cache_start_index_ TF_GUARDED_BY(mu_) = 0;
172 
173   // True if one thread is extending the cache.
174   bool extending_cache_ TF_GUARDED_BY(mu_) = false;
175 
176   // Maps trainer IDs to element indices. The indices are absolute indices
177   // within the dataset. The actual index to use with `cache_` would be
178   // `trainer_to_element_index_map_[trainer_id] - cache_start_index_`.
179   absl::flat_hash_map<std::string, size_t> trainer_to_element_index_map_
180       TF_GUARDED_BY(mu_);
181 };
182 
183 template <class ElementType>
CrossTrainerCache(size_t max_cache_size_bytes,std::unique_ptr<CachableSequence<ElementType>> cachable_sequence)184 CrossTrainerCache<ElementType>::CrossTrainerCache(
185     size_t max_cache_size_bytes,
186     std::unique_ptr<CachableSequence<ElementType>> cachable_sequence)
187     : max_cache_size_bytes_(max_cache_size_bytes),
188       cachable_sequence_(std::move(cachable_sequence)) {
189   DCHECK_GT(max_cache_size_bytes, 0)
190       << "CrossTrainerCache size must be greater than 0.";
191   VLOG(2) << "Initialized tf.data service cross-trainer cache with "
192           << FormatBytes(max_cache_size_bytes) << " of memory.";
193 }
194 
195 template <class ElementType>
196 StatusOr<std::shared_ptr<const ElementType>>
Get(const std::string & trainer_id)197 CrossTrainerCache<ElementType>::Get(const std::string& trainer_id)
198     TF_LOCKS_EXCLUDED(mu_) {
199   if (trainer_id.empty()) {
200     return errors::InvalidArgument(
201         "tf.data service cross-trainer cache requires a non-empty trainer ID.");
202   }
203 
204   TF_ASSIGN_OR_RETURN(CacheQueryResult result, GetCacheQueryResult(trainer_id));
205   RecordMetrics(result);
206   return result.element;
207 }
208 
209 template <class ElementType>
210 StatusOr<typename CrossTrainerCache<ElementType>::CacheQueryResult>
GetCacheQueryResult(const std::string & trainer_id)211 CrossTrainerCache<ElementType>::GetCacheQueryResult(
212     const std::string& trainer_id) {
213   bool should_extend_cache = false;
214   while (true) {
215     {
216       mutex_lock l(mu_);
217       TF_RETURN_IF_ERROR(status_);
218       if (IsElementReady(trainer_id)) {
219         TF_ASSIGN_OR_RETURN(std::shared_ptr<const ElementType> element,
220                             GetElement(trainer_id));
221         return CacheQueryResult{element,
222                                 /*is_cache_hit=*/!should_extend_cache};
223       }
224 
225       // Extends the cache or waits for another thread to extend the cache. When
226       // concurrent trainers wait for the next element, only one of them should
227       // extend the cache.
228       if (extending_cache_) {
229         should_extend_cache = false;
230         cv_.wait(l);
231       } else {
232         should_extend_cache = true;
233         extending_cache_ = true;
234       }
235     }
236 
237     if (should_extend_cache) {
238       Status s = ExtendCache();
239       mutex_lock l(mu_);
240       extending_cache_ = false;
241       cv_.notify_all();
242       TF_RETURN_IF_ERROR(s);
243     }
244   }
245 }
246 
247 template <class ElementType>
IsElementReady(const std::string & trainer_id)248 bool CrossTrainerCache<ElementType>::IsElementReady(
249     const std::string& trainer_id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
250   return GetElementIndex(trainer_id) < cache_start_index_ + cache_.size();
251 }
252 
253 template <class ElementType>
254 StatusOr<std::shared_ptr<const ElementType>>
GetElement(const std::string & trainer_id)255 CrossTrainerCache<ElementType>::GetElement(const std::string& trainer_id)
256     TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
257   size_t element_index = GetElementIndex(trainer_id);
258   if (element_index >= std::numeric_limits<size_t>::max()) {
259     return errors::Internal(
260         "tf.data service caching element index exceeds integer limit. Got ",
261         element_index);
262   }
263 
264   std::shared_ptr<const ElementType> result =
265       cache_[element_index - cache_start_index_];
266   trainer_to_element_index_map_[trainer_id] = element_index + 1;
267   return result;
268 }
269 
270 template <class ElementType>
GetElementIndex(const std::string & trainer_id)271 size_t CrossTrainerCache<ElementType>::GetElementIndex(
272     const std::string& trainer_id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
273   size_t element_index = trainer_to_element_index_map_[trainer_id];
274   if (element_index < cache_start_index_) {
275     element_index = cache_start_index_;
276   }
277   return element_index;
278 }
279 
280 template <class ElementType>
ExtendCache()281 Status CrossTrainerCache<ElementType>::ExtendCache() TF_LOCKS_EXCLUDED(mu_) {
282   TF_ASSIGN_OR_RETURN(ElementType element, cachable_sequence_->GetNext());
283   size_t new_element_size_bytes =
284       cachable_sequence_->GetElementSizeBytes(element);
285   if (new_element_size_bytes > max_cache_size_bytes_) {
286     return errors::InvalidArgument(
287         "tf.data service element size is larger than cache size in bytes. Got ",
288         "element size: ", new_element_size_bytes,
289         " and cache size: ", max_cache_size_bytes_);
290   }
291 
292   mutex_lock l(mu_);
293   TF_RETURN_IF_ERROR(status_);
294   FreeSpace(new_element_size_bytes);
295   cache_.push_back(std::make_shared<ElementType>(std::move(element)));
296   cache_size_bytes_ += new_element_size_bytes;
297   return OkStatus();
298 }
299 
300 template <class ElementType>
FreeSpace(size_t new_element_size_bytes)301 void CrossTrainerCache<ElementType>::FreeSpace(size_t new_element_size_bytes)
302     TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
303   size_t num_elements_discarded = 0;
304   while (!cache_.empty() &&
305          cache_size_bytes_ + new_element_size_bytes > max_cache_size_bytes_) {
306     size_t free_bytes =
307         cachable_sequence_->GetElementSizeBytes(*cache_.front());
308     cache_.pop_front();
309     cache_size_bytes_ -= free_bytes;
310     ++cache_start_index_;
311     ++num_elements_discarded;
312   }
313 
314   VLOG(3) << "Freed " << num_elements_discarded << " element(s) from "
315           << "tf.data service cross-trainer cache. Memory usage: "
316           << FormatBytes(cache_size_bytes_) << ".";
317 }
318 
319 template <class ElementType>
Cancel(Status status)320 void CrossTrainerCache<ElementType>::Cancel(Status status)
321     TF_LOCKS_EXCLUDED(mu_) {
322   DCHECK(!status.ok())
323       << "Cancelling CrossTrainerCache requires a non-OK status. Got "
324       << status;
325   VLOG(2) << "Cancel tf.data service cross-trainer cache with status "
326           << status;
327   mutex_lock l(mu_);
328   status_ = std::move(status);
329   cv_.notify_all();
330 }
331 
332 template <class ElementType>
IsCancelled()333 bool CrossTrainerCache<ElementType>::IsCancelled() const
334     TF_LOCKS_EXCLUDED(mu_) {
335   mutex_lock l(mu_);
336   return !status_.ok();
337 }
338 
339 template <class ElementType>
RecordMetrics(const CacheQueryResult & result)340 void CrossTrainerCache<ElementType>::RecordMetrics(
341     const CacheQueryResult& result) {
342   metrics::RecordTFDataServiceCrossTrainerCacheQuery(result.cache_hit);
343   size_t cache_size_bytes = 0;
344   {
345     mutex_lock l(mu_);
346     cache_size_bytes = cache_size_bytes_;
347   }
348   metrics::RecordTFDataServiceCrossTrainerCacheSizeBytes(cache_size_bytes);
349 }
350 
351 }  // namespace data
352 }  // namespace tensorflow
353 
#endif  // TENSORFLOW_CORE_DATA_SERVICE_CROSS_TRAINER_CACHE_H_
355