1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_DATA_SERVICE_THREAD_SAFE_BUFFER_H_ 16 #define TENSORFLOW_CORE_DATA_SERVICE_THREAD_SAFE_BUFFER_H_ 17 18 #include <deque> 19 #include <utility> 20 21 #include "tensorflow/core/platform/macros.h" 22 #include "tensorflow/core/platform/mutex.h" 23 #include "tensorflow/core/platform/status.h" 24 #include "tensorflow/core/platform/statusor.h" 25 26 namespace tensorflow { 27 namespace data { 28 29 // A thread-safe bounded buffer with cancellation support. 30 template <class T> 31 class ThreadSafeBuffer final { 32 public: 33 // Creates a buffer with the specified `buffer_size`. 34 // REQUIRES: buffer_size > 0 35 explicit ThreadSafeBuffer(size_t buffer_size); 36 37 // Gets the next element. Blocks if the buffer is empty. Returns an error if 38 // a non-OK status was pushed or the buffer has been cancelled. 39 StatusOr<T> Pop(); 40 41 // Writes the next element. Blocks if the buffer is full. Returns an error if 42 // the buffer has been cancelled. 43 Status Push(StatusOr<T> value); 44 45 // Cancels the buffer with `status` and notifies waiting threads. After 46 // cancelling, all `Push` and `Pop` calls will return `status`. 47 // REQUIRES: !status.ok() 48 void Cancel(Status status); 49 50 private: 51 const size_t buffer_size_; 52 53 mutex mu_; 54 condition_variable ready_to_pop_; 55 condition_variable ready_to_push_; 56 std::deque<StatusOr<T>> results_ TF_GUARDED_BY(mu_); 57 Status status_ TF_GUARDED_BY(mu_) = OkStatus(); 58 59 TF_DISALLOW_COPY_AND_ASSIGN(ThreadSafeBuffer); 60 }; 61 62 template <class T> ThreadSafeBuffer(size_t buffer_size)63ThreadSafeBuffer<T>::ThreadSafeBuffer(size_t buffer_size) 64 : buffer_size_(buffer_size) { 65 DCHECK_GT(buffer_size, 0) 66 << "ThreadSafeBuffer must have a positive buffer size. Got " 67 << buffer_size << "."; 68 } 69 70 template <class T> Pop()71StatusOr<T> ThreadSafeBuffer<T>::Pop() { 72 mutex_lock l(mu_); 73 while (status_.ok() && results_.empty()) { 74 ready_to_pop_.wait(l); 75 } 76 if (!status_.ok()) { 77 return status_; 78 } 79 StatusOr<T> result = std::move(results_.front()); 80 results_.pop_front(); 81 ready_to_push_.notify_one(); 82 return result; 83 } 84 85 template <class T> Push(StatusOr<T> value)86Status ThreadSafeBuffer<T>::Push(StatusOr<T> value) { 87 mutex_lock l(mu_); 88 while (status_.ok() && results_.size() >= buffer_size_) { 89 ready_to_push_.wait(l); 90 } 91 if (!status_.ok()) { 92 return status_; 93 } 94 results_.push_back(std::move(value)); 95 ready_to_pop_.notify_one(); 96 return OkStatus(); 97 } 98 99 template <class T> Cancel(Status status)100void ThreadSafeBuffer<T>::Cancel(Status status) { 101 DCHECK(!status.ok()) 102 << "Cancelling ThreadSafeBuffer requires a non-OK status. Got " << status; 103 mutex_lock l(mu_); 104 status_ = std::move(status); 105 ready_to_push_.notify_all(); 106 ready_to_pop_.notify_all(); 107 } 108 109 } // namespace data 110 } // namespace tensorflow 111 112 #endif // TENSORFLOW_CORE_DATA_SERVICE_THREAD_SAFE_BUFFER_H_ 113