/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DATA_SERVICE_TASK_RUNNER_H_
#define TENSORFLOW_CORE_DATA_SERVICE_TASK_RUNNER_H_

#include <memory>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/cross_trainer_cache.h"
#include "tensorflow/core/data/service/data_transfer.h"
#include "tensorflow/core/data/service/thread_safe_buffer.h"
#include "tensorflow/core/data/service/worker.pb.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/protobuf/service_config.pb.h"

namespace tensorflow {
namespace data {

// Iterator over a task's elements.
class TaskIterator {
 public:
  virtual ~TaskIterator() = default;
  // If the iterator is not yet exhausted, `GetNext` stores the next element in
  // `element` and sets `end_of_sequence` to `false`. Otherwise, it sets
  // `end_of_sequence` to `true`.
  virtual Status GetNext(std::vector<Tensor>& element,
                         bool& end_of_sequence) = 0;
  // Reports the cardinality of the dataset that created this iterator.
  virtual int64_t Cardinality() const = 0;
};

// Implementation of TaskIterator wrapping a standalone iterator.
class StandaloneTaskIterator : public TaskIterator {
 public:
  // `dataset` should be the dataset that created `iterator`.
  // StandaloneTaskIterator takes ownership of the dataset to ensure it
  // lives as long as `iterator`.
  StandaloneTaskIterator(std::unique_ptr<standalone::Dataset> dataset,
                         std::unique_ptr<standalone::Iterator> iterator);
  Status GetNext(std::vector<Tensor>& element, bool& end_of_sequence) override;
  int64_t Cardinality() const override;

 private:
  std::unique_ptr<standalone::Dataset> dataset_;
  std::unique_ptr<standalone::Iterator> iterator_;
};

// Interface for providing elements to task consumers.
class TaskRunner {
 public:
  // Creates a `TaskRunner` and stores it in `out`.
  static Status Create(const experimental::WorkerConfig& worker_config,
                       const TaskDef& task_def,
                       std::unique_ptr<TaskIterator> iterator,
                       std::unique_ptr<TaskRunner>& out);
  virtual ~TaskRunner() = default;
  // Gets the next element for the given request.
  virtual Status GetNext(const GetElementRequest& req,
                         GetElementResult& result) = 0;
  // Cancels in-progress `GetNext` requests.
  virtual void Cancel() = 0;
};
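
// Example usage (a sketch, not taken from the worker implementation;
// `dataset`, `iterator`, `worker_config`, and `task_def` are assumed to be
// provided by the surrounding worker code):
//
//   auto task_iterator = std::make_unique<StandaloneTaskIterator>(
//       std::move(dataset), std::move(iterator));
//   std::unique_ptr<TaskRunner> task_runner;
//   TF_RETURN_IF_ERROR(TaskRunner::Create(
//       worker_config, task_def, std::move(task_iterator), task_runner));
//
//   GetElementRequest req;
//   GetElementResult result;
//   TF_RETURN_IF_ERROR(task_runner->GetNext(req, result));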

// A task runner which provides elements on a first-come first-served basis.
// It does not consider which consumer is making the request.
class FirstComeFirstServedTaskRunner : public TaskRunner {
 public:
  explicit FirstComeFirstServedTaskRunner(
      std::unique_ptr<TaskIterator> iterator);
  ~FirstComeFirstServedTaskRunner() override;

  // Gets the next element. It may block if the element is not ready yet.
  Status GetNext(const GetElementRequest& req,
                 GetElementResult& result) override;
  Status GetNext(GetElementResult& result);

  void Cancel() override;

 private:
  // Function to continually prefetch the next element. Returns an error if the
  // task has been cancelled.
  Status PrefetchFn();

  // Runs `PrefetchFn` on a dedicated thread.
  void RunPrefetchThread();

  // Gets the next element from the input iterator.
  StatusOr<GetElementResult> GetNextFromInputIterator() TF_LOCKS_EXCLUDED(mu_);

  mutex mu_;
  std::unique_ptr<TaskIterator> iterator_ TF_GUARDED_BY(mu_);
  int64_t element_index_ TF_GUARDED_BY(mu_) = 0;

  ThreadSafeBuffer<GetElementResult> buffer_;
  std::unique_ptr<Thread> prefetch_thread_;

  TF_DISALLOW_COPY_AND_ASSIGN(FirstComeFirstServedTaskRunner);
};

// A task runner which prefetches elements on a first-come first-served basis
// and caches elements in a sliding-window `CrossTrainerCache`. The cache has a
// bounded size and progresses when a trainer has consumed all the elements in
// the cache. Trainers read from a sliding window of the dataset and may not
// read the full dataset.
class CachingTaskRunner : public TaskRunner {
 public:
  explicit CachingTaskRunner(std::unique_ptr<TaskIterator> iterator,
                             size_t max_cache_size_bytes);
  ~CachingTaskRunner() override;

  // Gets the next element from the cross-trainer cache, blocking if the data
  // is not ready.
  // REQUIRES: !req.trainer_id().empty()
  Status GetNext(const GetElementRequest& req,
                 GetElementResult& result) override;

  // Cancels the task runner. After cancellation, all `GetNext` calls will
  // return a Cancelled status.
  void Cancel() override;

 private:
  // `GetElementResultSequence` generates a sequence of elements from the
  // `FirstComeFirstServedTaskRunner`. It is used by the `CrossTrainerCache`
  // to generate cached elements.
  class GetElementResultSequence : public CachableSequence<GetElementResult> {
   public:
    explicit GetElementResultSequence(
        FirstComeFirstServedTaskRunner& fcfs_task_runner);
    StatusOr<GetElementResult> GetNext() override;
    size_t GetElementSizeBytes(const GetElementResult& element) const override;

   private:
    FirstComeFirstServedTaskRunner& fcfs_task_runner_;
  };

  FirstComeFirstServedTaskRunner fcfs_task_runner_;
  CrossTrainerCache<GetElementResult> cache_;

  TF_DISALLOW_COPY_AND_ASSIGN(CachingTaskRunner);
};
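
// Example usage of `CachingTaskRunner` (a sketch; `iterator`, the 10GB cache
// size, and the trainer ID are illustrative assumptions, not values used by
// the service):
//
//   CachingTaskRunner runner(std::move(iterator),
//                            /*max_cache_size_bytes=*/10 * (size_t{1} << 30));
//   GetElementRequest req;
//   req.set_trainer_id("trainer_1");  // Must be non-empty.
//   GetElementResult result;
//   TF_RETURN_IF_ERROR(runner.GetNext(req, result));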

// An element produced by a task.
struct Element {
  explicit Element(std::vector<Tensor>&& components, int64_t index)
      : components(std::move(components)), index(index) {}
  // The components of the element.
  std::vector<Tensor> components;
  // The element's index within the task, e.g. 0 for the first element produced
  // by the task, 1 for the second element, etc.
  int64_t index;
};

// Thread for prefetching a round worth of elements.
class PrefetchThread {
 public:
  explicit PrefetchThread(std::unique_ptr<TaskIterator> iterator,
                          int64_t round_size);
  ~PrefetchThread();
  // Runs the prefetch thread. It runs until an error is encountered or the
  // destructor is called.
  void Run();
  // Fills `out` with a round of data. Waits for up to `wait_us` microseconds
  // before giving up and returning with `out` empty. A negative `wait_us`
  // signals to wait indefinitely.
  Status FillBuffer(int64_t wait_us,
                    std::vector<std::unique_ptr<Element>>& out);
  // Returns the status for any failures encountered by the prefetch thread.
  Status GetStatus();

 private:
  const std::unique_ptr<TaskIterator> iterator_;
  const int64_t round_size_;
  mutex mu_;
  int64_t index_ TF_GUARDED_BY(mu_) = 0;
  // Buffered results for the next round.
  std::vector<std::unique_ptr<Element>> buffer_ TF_GUARDED_BY(mu_);
  // The status if the prefetch thread fails.
  Status status_ TF_GUARDED_BY(mu_) = OkStatus();
  // Condition variable notified when elements are added to or removed from
  // `buffer_`, or when `status_` is changed.
  condition_variable cv_;
  bool cancelled_ TF_GUARDED_BY(mu_) = false;
  // Thread which constantly tries to fill `buffer_` up with
  // `round_size_` elements.
  std::unique_ptr<Thread> thread_;
};
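
// The `RoundRobinTaskRunner` declared below hands out one element per consumer
// per round. A minimal consumer-side sketch of the expected request pattern
// (`runner`, `consumer_index`, and the loop are assumptions for illustration;
// field names follow `GetElementRequest`):
//
//   GetElementRequest req;
//   req.set_consumer_index(consumer_index);  // Fixed for this consumer.
//   for (int64_t round = 0; ; ++round) {
//     req.set_round_index(round);  // Consecutive rounds, starting at 0.
//     GetElementResult result;
//     TF_RETURN_IF_ERROR(runner.GetNext(req, result));
//     if (result.end_of_sequence) break;
//   }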

// A task runner which enforces round-robin order for consuming a task's
// elements. `RoundRobinTaskRunner` provides elements in a series of "rounds".
// In each successive round, the runner waits to receive requests from all
// consumers. These requests are blocked until all requests arrive. Once all
// requests arrive, the runner hands out elements to consumers in order of
// their consumer indices.
//
// Consumers are expected to successively request consecutive element indices,
// starting at 0. The same element can be requested multiple times by the same
// consumer, as long as the consumer hasn't yet requested the next element (at
// the start of each round we discard elements from the previous round).
//
// If the worker restarts mid-round, a situation arises where some consumers
// are requesting element index `n` while others are requesting element index
// `n + 1`. To remedy this, the first round after restart may be a partial
// round, where we only serve elements to consumers requesting data for element
// index `n`, blocking other consumers until the second round.
class RoundRobinTaskRunner : public TaskRunner {
 public:
  RoundRobinTaskRunner(std::unique_ptr<TaskIterator> iterator,
                       int64_t num_consumers, string worker_address);

  Status GetNext(const GetElementRequest& req,
                 GetElementResult& result) override;
  void Cancel() override;

 private:
  // Prepares a full round of data. `wait_us` indicates how long to wait before
  // skipping if a full round of data is not yet ready.
  Status PrepareFullRound(int64_t wait_us) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Prepares a partial round to get consumers back in sync.
  Status PreparePartialRound() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  Status ValidateRequest(const GetElementRequest& req);
  // Prepares data for the next round, blocking until the round is ready to
  // start.
  Status PrepareRound(const GetElementRequest& req);
  const int64_t num_consumers_;
  const string worker_address_;
  mutex mu_;
  bool cancelled_ TF_GUARDED_BY(mu_) = false;
  // Condition variable notified whenever we start a new round of round-robin.
  condition_variable new_round_cv_;
  // Outstanding requests, indexed by round number and then consumer index.
  absl::flat_hash_map<int64_t,
                      absl::flat_hash_map<int64_t, const GetElementRequest*>>
      requests_ TF_GUARDED_BY(mu_);
  // Index of the first round we plan to serve. At startup, this is the minimum
  // of all requested element indices.
  int64_t first_round_ TF_GUARDED_BY(mu_) = kint64max;
  int64_t current_round_ TF_GUARDED_BY(mu_) = -1;
  bool round_skipped_ TF_GUARDED_BY(mu_) = false;
  // Buffered results for the current round.
  std::vector<std::unique_ptr<Element>> buffer_ TF_GUARDED_BY(mu_);
  // Thread which constantly tries to prepare `num_consumers` elements for the
  // next round.
  PrefetchThread prefetch_thread_;
};

}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DATA_SERVICE_TASK_RUNNER_H_