/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DATA_SERVICE_CROSS_TRAINER_CACHE_H_
#define TENSORFLOW_CORE_DATA_SERVICE_CROSS_TRAINER_CACHE_H_

#include <cstddef>
#include <deque>
#include <functional>
#include <limits>
#include <memory>
#include <string>
#include <utility>
26 #include "absl/container/flat_hash_map.h"
27 #include "tensorflow/core/data/service/logging_utils.h"
28 #include "tensorflow/core/framework/metrics.h"
29 #include "tensorflow/core/platform/errors.h"
30 #include "tensorflow/core/platform/mutex.h"
31 #include "tensorflow/core/platform/status.h"
32 #include "tensorflow/core/platform/statusor.h"
33 #include "tensorflow/core/platform/thread_annotations.h"

namespace tensorflow {
namespace data {

// Sliding-window cache shared across concurrent trainers. Readers call `Get`
// to read elements they haven't read before. After a trainer reads an element,
// the element remains in the cache and is shared with other trainers. This is
// useful when multiple models are trained on the same data produced by an
// expensive computation, for example, during hyperparameter tuning.
//
// The cache progresses when a trainer that has consumed every element in the
// cache requests additional data. The cache has a bounded size, and elements
// are garbage collected when it becomes full. Consequently, trainers read from
// a sliding window over the dataset and may not read the full dataset.
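// For example, if the cache can hold only two elements, then after elements
// 1, 2, and 3 have been produced, element 1 has been discarded, and a trainer
// that joins late starts reading from element 2 rather than from the
// beginning.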
//
// The `CrossTrainerCache` class is thread-safe.
//
// Example usage:
//
//   // `InfiniteRange` returns 1, 2, 3, ... in the `GetNext` calls.
//   class InfiniteRange : public CachableSequence<int64_t> {
//    public:
//     StatusOr<int64_t> GetNext() override {
//       return next_++;
//     }
//
//     size_t GetElementSizeBytes(const int64_t& element) const override {
//       return sizeof(element);
//     }
//
//    private:
//     int64_t next_ = 1;
//   };
//
//   CrossTrainerCache<int64_t> cache(
//       /*max_cache_size_bytes=*/10 * (size_t{1} << 30),  // 10GB
//       std::make_unique<InfiniteRange>());
//
//   std::shared_ptr<const int64_t> next;
//   TF_ASSIGN_OR_RETURN(next, cache.Get("Trainer 1"));  // Returns 1
//   TF_ASSIGN_OR_RETURN(next, cache.Get("Trainer 2"));  // Returns 1
//   TF_ASSIGN_OR_RETURN(next, cache.Get("Trainer 1"));  // Returns 2
//   TF_ASSIGN_OR_RETURN(next, cache.Get("Trainer 2"));  // Returns 2

// To use the cache, the user needs to define a `CachableSequence` to generate
// an infinite sequence of data. It should implement a `GetNext` method to
// produce elements, and a `GetElementSizeBytes` method to estimate the element
// size in bytes.
template <class ElementType>
class CachableSequence {
 public:
  virtual ~CachableSequence() = default;

  // Returns the next element to be cached.
  virtual StatusOr<ElementType> GetNext() = 0;

  // Returns the estimated size of the element in bytes.
  virtual size_t GetElementSizeBytes(const ElementType&) const = 0;
};
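
// For variable-size elements, `GetElementSizeBytes` should also count the
// payload, not just the object header. A hypothetical sketch (not part of
// this library; `InfiniteStrings` is illustrative only):
//
//   class InfiniteStrings : public CachableSequence<std::string> {
//    public:
//     StatusOr<std::string> GetNext() override {
//       return std::string(/*count=*/100, 'x');  // 100-character elements.
//     }
//
//     size_t GetElementSizeBytes(const std::string& element) const override {
//       // The string object itself plus its heap-allocated characters.
//       return sizeof(element) + element.size();
//     }
//   };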

// Sliding-window cache shared across concurrent trainers.
template <class ElementType>
class CrossTrainerCache {
 public:
  // Creates a `CrossTrainerCache` with a memory budget of
  // `max_cache_size_bytes`. The cache should be able to hold at least one
  // element, i.e.:
  // REQUIRES: `max_cache_size_bytes >= max(GetElementSizeBytes(*))`
  explicit CrossTrainerCache(
      size_t max_cache_size_bytes,
      std::unique_ptr<CachableSequence<ElementType>> cachable_sequence);
  virtual ~CrossTrainerCache() = default;
  CrossTrainerCache(const CrossTrainerCache&) = delete;
  CrossTrainerCache& operator=(const CrossTrainerCache&) = delete;

  // Gets the next element for a trainer. A `trainer_id` identifies the trainer
  // reading from the cache. A trainer reads the next element it hasn't read
  // before. After a trainer reads data, the data is cached and reused by other
  // trainers.
  StatusOr<std::shared_ptr<const ElementType>> Get(
      const std::string& trainer_id);

  // Cancels the cache with `status` and notifies the readers. After
  // cancelling, all `Get` calls will return `status`.
  // REQUIRES: !status.ok()
  void Cancel(Status status);

  // Returns true if the cache has been cancelled.
  bool IsCancelled() const;
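
  // For example (a usage sketch; `errors::Cancelled` produces the non-OK
  // status required by `Cancel`):
  //
  //   cache.Cancel(errors::Cancelled("Training finished."));
  //   cache.IsCancelled();     // Returns true.
  //   cache.Get("Trainer 1");  // Returns the Cancelled status.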

 private:
  struct CacheQueryResult {
    std::shared_ptr<const ElementType> element;
    bool cache_hit;
  };

  // Returns the next element and metrics about this query.
  StatusOr<CacheQueryResult> GetCacheQueryResult(const std::string& trainer_id);

  // Returns true if the element is ready for `trainer_id`. An element is ready
  // if other trainers have read the data and the data remains in the cache. If
  // the data is not ready, one of the trainers needs to extend the cache.
  bool IsElementReady(const std::string& trainer_id);

  // Returns the absolute element index relative to the dataset (not relative
  // to the cached elements).
  size_t GetElementIndex(const std::string& trainer_id);

  // Returns the next element for `trainer_id`.
  StatusOr<std::shared_ptr<const ElementType>> GetElement(
      const std::string& trainer_id);

  // Reads a new element and writes it into the cache.
  Status ExtendCache();

  // Frees old elements to keep the cache size below `max_cache_size_bytes_`.
  // `new_element_size_bytes` is the size of the new element being inserted.
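  // For example (a worked sketch): with a 10-byte budget and cached element
  // sizes {4, 4}, inserting a new 4-byte element first evicts the front
  // element, since 8 + 4 > 10, leaving one 4-byte element before the insert.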
  void FreeSpace(size_t new_element_size_bytes);

  // Records the cache hit rate and cache size.
  void RecordMetrics(const CacheQueryResult& result);

  // Maximum cache size in bytes.
  const size_t max_cache_size_bytes_;

  // The element sequence over which the sliding-window cache operates.
  std::unique_ptr<CachableSequence<ElementType>> cachable_sequence_;

  mutable mutex mu_;
  mutable condition_variable cv_;

  // If `status_` is non-OK, the cache is cancelled, and all method calls will
  // return this status.
  Status status_ TF_GUARDED_BY(mu_) = OkStatus();

  // `cache_` stores the cached elements.
  std::deque<std::shared_ptr<const ElementType>> cache_ TF_GUARDED_BY(mu_);
  size_t cache_size_bytes_ TF_GUARDED_BY(mu_) = 0;
  size_t cache_start_index_ TF_GUARDED_BY(mu_) = 0;

  // True if one thread is extending the cache.
  bool extending_cache_ TF_GUARDED_BY(mu_) = false;

  // Maps trainer IDs to element indices. The indices are absolute indices
  // within the dataset. The actual index to use with `cache_` would be
  // `trainer_to_element_index_map_[trainer_id] - cache_start_index_`.
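  // For example, if `cache_start_index_` is 5 and a trainer's stored index is
  // 7, the trainer's next element is `cache_[2]`.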
  absl::flat_hash_map<std::string, size_t> trainer_to_element_index_map_
      TF_GUARDED_BY(mu_);
};

template <class ElementType>
CrossTrainerCache<ElementType>::CrossTrainerCache(
    size_t max_cache_size_bytes,
    std::unique_ptr<CachableSequence<ElementType>> cachable_sequence)
    : max_cache_size_bytes_(max_cache_size_bytes),
      cachable_sequence_(std::move(cachable_sequence)) {
  DCHECK_GT(max_cache_size_bytes, 0)
      << "CrossTrainerCache size must be greater than 0.";
  VLOG(2) << "Initialized tf.data service cross-trainer cache with "
          << FormatBytes(max_cache_size_bytes) << " of memory.";
}

template <class ElementType>
StatusOr<std::shared_ptr<const ElementType>>
CrossTrainerCache<ElementType>::Get(const std::string& trainer_id)
    TF_LOCKS_EXCLUDED(mu_) {
  if (trainer_id.empty()) {
    return errors::InvalidArgument(
        "tf.data service cross-trainer cache requires a non-empty trainer ID.");
  }

  TF_ASSIGN_OR_RETURN(CacheQueryResult result, GetCacheQueryResult(trainer_id));
  RecordMetrics(result);
  return result.element;
}

template <class ElementType>
StatusOr<typename CrossTrainerCache<ElementType>::CacheQueryResult>
CrossTrainerCache<ElementType>::GetCacheQueryResult(
    const std::string& trainer_id) {
  bool should_extend_cache = false;
  while (true) {
    {
      mutex_lock l(mu_);
      TF_RETURN_IF_ERROR(status_);
      if (IsElementReady(trainer_id)) {
        TF_ASSIGN_OR_RETURN(std::shared_ptr<const ElementType> element,
                            GetElement(trainer_id));
        return CacheQueryResult{element,
                                /*cache_hit=*/!should_extend_cache};
      }

      // Extends the cache or waits for another thread to extend the cache.
      // When concurrent trainers wait for the next element, only one of them
      // should extend the cache.
      if (extending_cache_) {
        should_extend_cache = false;
        cv_.wait(l);
      } else {
        should_extend_cache = true;
        extending_cache_ = true;
      }
    }

    if (should_extend_cache) {
      Status s = ExtendCache();
      mutex_lock l(mu_);
      extending_cache_ = false;
      cv_.notify_all();
      TF_RETURN_IF_ERROR(s);
    }
  }
}

template <class ElementType>
bool CrossTrainerCache<ElementType>::IsElementReady(
    const std::string& trainer_id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  return GetElementIndex(trainer_id) < cache_start_index_ + cache_.size();
}

template <class ElementType>
StatusOr<std::shared_ptr<const ElementType>>
CrossTrainerCache<ElementType>::GetElement(const std::string& trainer_id)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  size_t element_index = GetElementIndex(trainer_id);
  if (element_index >= std::numeric_limits<size_t>::max()) {
    return errors::Internal(
        "tf.data service caching element index exceeds integer limit. Got ",
        element_index);
  }

  std::shared_ptr<const ElementType> result =
      cache_[element_index - cache_start_index_];
  trainer_to_element_index_map_[trainer_id] = element_index + 1;
  return result;
}

template <class ElementType>
size_t CrossTrainerCache<ElementType>::GetElementIndex(
    const std::string& trainer_id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  size_t element_index = trainer_to_element_index_map_[trainer_id];
  if (element_index < cache_start_index_) {
    element_index = cache_start_index_;
  }
  return element_index;
}

template <class ElementType>
Status CrossTrainerCache<ElementType>::ExtendCache() TF_LOCKS_EXCLUDED(mu_) {
  TF_ASSIGN_OR_RETURN(ElementType element, cachable_sequence_->GetNext());
  size_t new_element_size_bytes =
      cachable_sequence_->GetElementSizeBytes(element);
  if (new_element_size_bytes > max_cache_size_bytes_) {
    return errors::InvalidArgument(
        "tf.data service element size is larger than cache size in bytes. Got ",
        "element size: ", new_element_size_bytes,
        " and cache size: ", max_cache_size_bytes_);
  }

  mutex_lock l(mu_);
  TF_RETURN_IF_ERROR(status_);
  FreeSpace(new_element_size_bytes);
  cache_.push_back(std::make_shared<ElementType>(std::move(element)));
  cache_size_bytes_ += new_element_size_bytes;
  return OkStatus();
}

template <class ElementType>
void CrossTrainerCache<ElementType>::FreeSpace(size_t new_element_size_bytes)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  size_t num_elements_discarded = 0;
  while (!cache_.empty() &&
         cache_size_bytes_ + new_element_size_bytes > max_cache_size_bytes_) {
    size_t free_bytes =
        cachable_sequence_->GetElementSizeBytes(*cache_.front());
    cache_.pop_front();
    cache_size_bytes_ -= free_bytes;
    ++cache_start_index_;
    ++num_elements_discarded;
  }

  VLOG(3) << "Freed " << num_elements_discarded << " element(s) from "
          << "tf.data service cross-trainer cache. Memory usage: "
          << FormatBytes(cache_size_bytes_) << ".";
}

template <class ElementType>
void CrossTrainerCache<ElementType>::Cancel(Status status)
    TF_LOCKS_EXCLUDED(mu_) {
  DCHECK(!status.ok())
      << "Cancelling CrossTrainerCache requires a non-OK status. Got "
      << status;
  VLOG(2) << "Cancel tf.data service cross-trainer cache with status "
          << status;
  mutex_lock l(mu_);
  status_ = std::move(status);
  cv_.notify_all();
}

template <class ElementType>
bool CrossTrainerCache<ElementType>::IsCancelled() const
    TF_LOCKS_EXCLUDED(mu_) {
  mutex_lock l(mu_);
  return !status_.ok();
}

template <class ElementType>
void CrossTrainerCache<ElementType>::RecordMetrics(
    const CacheQueryResult& result) {
  metrics::RecordTFDataServiceCrossTrainerCacheQuery(result.cache_hit);
  size_t cache_size_bytes = 0;
  {
    mutex_lock l(mu_);
    cache_size_bytes = cache_size_bytes_;
  }
  metrics::RecordTFDataServiceCrossTrainerCacheSizeBytes(cache_size_bytes);
}

}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DATA_SERVICE_CROSS_TRAINER_CACHE_H_
355