xref: /aosp_15_r20/external/tensorflow/tensorflow/core/data/service/thread_safe_buffer.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_DATA_SERVICE_THREAD_SAFE_BUFFER_H_
16 #define TENSORFLOW_CORE_DATA_SERVICE_THREAD_SAFE_BUFFER_H_
17 
18 #include <deque>
19 #include <utility>
20 
21 #include "tensorflow/core/platform/macros.h"
22 #include "tensorflow/core/platform/mutex.h"
23 #include "tensorflow/core/platform/status.h"
24 #include "tensorflow/core/platform/statusor.h"
25 
26 namespace tensorflow {
27 namespace data {
28 
29 // A thread-safe bounded buffer with cancellation support.
30 template <class T>
31 class ThreadSafeBuffer final {
32  public:
33   // Creates a buffer with the specified `buffer_size`.
34   // REQUIRES: buffer_size > 0
35   explicit ThreadSafeBuffer(size_t buffer_size);
36 
37   // Gets the next element. Blocks if the buffer is empty. Returns an error if
38   // a non-OK status was pushed or the buffer has been cancelled.
39   StatusOr<T> Pop();
40 
41   // Writes the next element. Blocks if the buffer is full. Returns an error if
42   // the buffer has been cancelled.
43   Status Push(StatusOr<T> value);
44 
45   // Cancels the buffer with `status` and notifies waiting threads. After
46   // cancelling, all `Push` and `Pop` calls will return `status`.
47   // REQUIRES: !status.ok()
48   void Cancel(Status status);
49 
50  private:
51   const size_t buffer_size_;
52 
53   mutex mu_;
54   condition_variable ready_to_pop_;
55   condition_variable ready_to_push_;
56   std::deque<StatusOr<T>> results_ TF_GUARDED_BY(mu_);
57   Status status_ TF_GUARDED_BY(mu_) = OkStatus();
58 
59   TF_DISALLOW_COPY_AND_ASSIGN(ThreadSafeBuffer);
60 };
61 
62 template <class T>
ThreadSafeBuffer(size_t buffer_size)63 ThreadSafeBuffer<T>::ThreadSafeBuffer(size_t buffer_size)
64     : buffer_size_(buffer_size) {
65   DCHECK_GT(buffer_size, 0)
66       << "ThreadSafeBuffer must have a positive buffer size. Got "
67       << buffer_size << ".";
68 }
69 
70 template <class T>
Pop()71 StatusOr<T> ThreadSafeBuffer<T>::Pop() {
72   mutex_lock l(mu_);
73   while (status_.ok() && results_.empty()) {
74     ready_to_pop_.wait(l);
75   }
76   if (!status_.ok()) {
77     return status_;
78   }
79   StatusOr<T> result = std::move(results_.front());
80   results_.pop_front();
81   ready_to_push_.notify_one();
82   return result;
83 }
84 
85 template <class T>
Push(StatusOr<T> value)86 Status ThreadSafeBuffer<T>::Push(StatusOr<T> value) {
87   mutex_lock l(mu_);
88   while (status_.ok() && results_.size() >= buffer_size_) {
89     ready_to_push_.wait(l);
90   }
91   if (!status_.ok()) {
92     return status_;
93   }
94   results_.push_back(std::move(value));
95   ready_to_pop_.notify_one();
96   return OkStatus();
97 }
98 
99 template <class T>
Cancel(Status status)100 void ThreadSafeBuffer<T>::Cancel(Status status) {
101   DCHECK(!status.ok())
102       << "Cancelling ThreadSafeBuffer requires a non-OK status. Got " << status;
103   mutex_lock l(mu_);
104   status_ = std::move(status);
105   ready_to_push_.notify_all();
106   ready_to_pop_.notify_all();
107 }
108 
109 }  // namespace data
110 }  // namespace tensorflow
111 
112 #endif  // TENSORFLOW_CORE_DATA_SERVICE_THREAD_SAFE_BUFFER_H_
113