// Copyright 2019 The Marl Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef marl_scheduler_h
#define marl_scheduler_h

#include "containers.h"
#include "debug.h"
#include "deprecated.h"
#include "export.h"
#include "memory.h"
#include "mutex.h"
#include "sanitizers.h"
#include "task.h"
#include "thread.h"
#include "thread_local.h"

#include <array>
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <functional>
#include <thread>

namespace marl {

class OSFiber;

// Scheduler asynchronously processes Tasks.
// A scheduler can be bound to one or more threads using the bind() method.
// Once bound to a thread, that thread can call marl::schedule() to enqueue
// work tasks to be executed asynchronously.
// Schedulers are initially constructed in single-threaded mode.
// Call Config::setWorkerThreadCount() to spawn dedicated worker threads.
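//
// Example usage (a sketch; assumes marl/defer.h is also included for the
// defer() macro):
//
//   marl::Scheduler scheduler(marl::Scheduler::Config::allCores());
//   scheduler.bind();
//   defer(scheduler.unbind());  // unbind when the enclosing scope exits.
//
//   marl::schedule([] {
//     // Task body, run asynchronously on one of the worker threads.
//   });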
class Scheduler {
  class Worker;

 public:
  using TimePoint = std::chrono::system_clock::time_point;
  using Predicate = std::function<bool()>;
  using ThreadInitializer = std::function<void(int workerId)>;

  // Config holds scheduler configuration settings that can be passed to the
  // Scheduler constructor.
  struct Config {
    static constexpr size_t DefaultFiberStackSize = 1024 * 1024;

    // Per-worker-thread settings.
    struct WorkerThread {
      // Total number of dedicated worker threads to spawn for the scheduler.
      int count = 0;

      // Initializer function to call after thread creation and before any
      // work is run by the thread.
      ThreadInitializer initializer;

      // Thread affinity policy to use for worker threads.
      std::shared_ptr<Thread::Affinity::Policy> affinityPolicy;
    };

    WorkerThread workerThread;

    // Memory allocator to use for the scheduler and internal allocations.
    Allocator* allocator = Allocator::Default;

    // Size of each fiber stack. This may be rounded up to the nearest
    // allocation granularity for the given platform.
    size_t fiberStackSize = DefaultFiberStackSize;

    // allCores() returns a Config with a worker thread for each of the logical
    // cpus available to the process.
    MARL_EXPORT
    static Config allCores();

    // Fluent setters that return this Config so set calls can be chained.
    MARL_NO_EXPORT inline Config& setAllocator(Allocator*);
    MARL_NO_EXPORT inline Config& setFiberStackSize(size_t);
    MARL_NO_EXPORT inline Config& setWorkerThreadCount(int);
    MARL_NO_EXPORT inline Config& setWorkerThreadInitializer(
        const ThreadInitializer&);
    MARL_NO_EXPORT inline Config& setWorkerThreadAffinityPolicy(
        const std::shared_ptr<Thread::Affinity::Policy>&);
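
    // Example of fluent configuration (a sketch; the values are
    // illustrative):
    //
    //   marl::Scheduler::Config cfg;
    //   cfg.setWorkerThreadCount(4)
    //      .setFiberStackSize(512 * 1024);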
  };

  // Constructor.
  MARL_EXPORT
  Scheduler(const Config&);

  // Destructor.
  // Blocks until the scheduler is unbound from all threads before returning.
  MARL_EXPORT
  ~Scheduler();

  // get() returns the scheduler bound to the current thread.
  MARL_EXPORT
  static Scheduler* get();

  // bind() binds this scheduler to the current thread.
  // There must be no existing scheduler bound to the thread prior to calling.
  MARL_EXPORT
  void bind();

  // unbind() unbinds the scheduler currently bound to the current thread.
  // There must be an existing scheduler bound to the thread prior to calling.
  // unbind() flushes any enqueued tasks on the single-threaded worker before
  // returning.
  MARL_EXPORT
  static void unbind();

  // enqueue() queues the task for asynchronous execution.
  MARL_EXPORT
  void enqueue(Task&& task);

  // config() returns the Config that was used to build the scheduler.
  MARL_EXPORT
  const Config& config() const;

  // Fibers expose methods to perform cooperative multitasking and are
  // automatically created by the Scheduler.
  //
  // The currently executing Fiber can be obtained by calling Fiber::current().
  //
  // When execution becomes blocked, wait() can be called to suspend execution
  // of the fiber and start executing other pending work. Once the block has
  // been lifted, notify() can be called to reschedule the Fiber on the same
  // thread that previously executed it.
  class Fiber {
   public:
    // current() returns the currently executing fiber, or nullptr if called
    // without a bound scheduler.
    MARL_EXPORT
    static Fiber* current();

    // wait() suspends execution of this Fiber until the Fiber is woken up with
    // a call to notify() and the predicate pred returns true.
    // If the predicate pred does not return true when notify() is called, then
    // the Fiber is automatically re-suspended, and will need to be woken with
    // another call to notify().
    // While the Fiber is suspended, the scheduler thread may continue
    // executing other tasks.
    // lock must be locked before calling, and is unlocked by wait() just
    // before the Fiber is suspended, and re-locked before the fiber is
    // resumed. lock will be locked before wait() returns.
    // pred will always be called with the lock held.
    // wait() must only be called on the currently executing fiber.
    MARL_EXPORT
    void wait(marl::lock& lock, const Predicate& pred);
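
    // Example of the wait() / notify() pattern (a sketch; mutex, ready and
    // fiber are illustrative names):
    //
    //   marl::mutex mutex;
    //   bool ready = false;  // guarded by mutex
    //   auto* fiber = marl::Scheduler::Fiber::current();
    //
    //   // On the waiting fiber:
    //   {
    //     marl::lock lock(mutex);
    //     fiber->wait(lock, [&] { return ready; });  // suspends until ready
    //   }
    //
    //   // On the notifying thread:
    //   mutex.lock();
    //   ready = true;
    //   mutex.unlock();
    //   fiber->notify();  // pred is re-evaluated with the lock held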

    // wait() suspends execution of this Fiber until the Fiber is woken up with
    // a call to notify() and the predicate pred returns true, or sometime
    // after the timeout is reached.
    // If the predicate pred does not return true when notify() is called, then
    // the Fiber is automatically re-suspended, and will need to be woken with
    // another call to notify() or will be woken sometime after the timeout is
    // reached.
    // While the Fiber is suspended, the scheduler thread may continue
    // executing other tasks.
    // lock must be locked before calling, and is unlocked by wait() just
    // before the Fiber is suspended, and re-locked before the fiber is
    // resumed. lock will be locked before wait() returns.
    // pred will always be called with the lock held.
    // wait() must only be called on the currently executing fiber.
    template <typename Clock, typename Duration>
    MARL_NO_EXPORT inline bool wait(
        marl::lock& lock,
        const std::chrono::time_point<Clock, Duration>& timeout,
        const Predicate& pred);
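
    // Example of a timed wait, reusing the names from the example above
    // (a sketch; the return value is false if the timeout was reached before
    // the predicate returned true):
    //
    //   marl::lock lock(mutex);
    //   bool signalled = fiber->wait(
    //       lock,
    //       std::chrono::system_clock::now() + std::chrono::seconds(1),
    //       [&] { return ready; });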

    // wait() suspends execution of this Fiber until the Fiber is woken up with
    // a call to notify().
    // While the Fiber is suspended, the scheduler thread may continue
    // executing other tasks.
    // wait() must only be called on the currently executing fiber.
    //
    // Warning: Unlike wait() overloads that take a lock and predicate, this
    // form of wait() offers no safety for notify() signals that occur before
    // the fiber is suspended, when signalling between different threads. In
    // this scenario you may deadlock. For this reason, it is only ever
    // recommended to use this overload if you can guarantee that the calls to
    // wait() and notify() are made by the same thread.
    //
    // Use with extreme caution.
    MARL_NO_EXPORT inline void wait();

    // wait() suspends execution of this Fiber until the Fiber is woken up with
    // a call to notify(), or sometime after the timeout is reached.
    // While the Fiber is suspended, the scheduler thread may continue
    // executing other tasks.
    // wait() must only be called on the currently executing fiber.
    //
    // Warning: Unlike wait() overloads that take a lock and predicate, this
    // form of wait() offers no safety for notify() signals that occur before
    // the fiber is suspended, when signalling between different threads. For
    // this reason, it is only ever recommended to use this overload if you can
    // guarantee that the calls to wait() and notify() are made by the same
    // thread.
    //
    // Use with extreme caution.
    template <typename Clock, typename Duration>
    MARL_NO_EXPORT inline bool wait(
        const std::chrono::time_point<Clock, Duration>& timeout);

    // notify() reschedules the suspended Fiber for execution.
    // notify() is usually only called when the predicate for one or more
    // wait() calls will likely return true.
    MARL_EXPORT
    void notify();

    // id is the thread-unique identifier of the Fiber.
    uint32_t const id;

   private:
    friend class Allocator;
    friend class Scheduler;

    enum class State {
      // Idle: the Fiber is currently unused, and sits in Worker::idleFibers,
      // ready to be recycled.
      Idle,

      // Yielded: the Fiber is currently blocked on a wait() call with no
      // timeout.
      Yielded,

      // Waiting: the Fiber is currently blocked on a wait() call with a
      // timeout. The fiber is sitting in the Worker::Work::waiting queue.
      Waiting,

      // Queued: the Fiber is currently queued for execution in the
      // Worker::Work::fibers queue.
      Queued,

      // Running: the Fiber is currently executing.
      Running,
    };

    Fiber(Allocator::unique_ptr<OSFiber>&&, uint32_t id);

    // switchTo() switches execution to the given fiber.
    // switchTo() must only be called on the currently executing fiber.
    void switchTo(Fiber*);

    // create() constructs and returns a new fiber with the given identifier,
    // stack size and func that will be executed when switched to.
    static Allocator::unique_ptr<Fiber> create(
        Allocator* allocator,
        uint32_t id,
        size_t stackSize,
        const std::function<void()>& func);

    // createFromCurrentThread() constructs and returns a new fiber with the
    // given identifier for the current thread.
    static Allocator::unique_ptr<Fiber> createFromCurrentThread(
        Allocator* allocator,
        uint32_t id);

    // toString() returns a string representation of the given State.
    // Used for debugging.
    static const char* toString(State state);

    Allocator::unique_ptr<OSFiber> const impl;
    Worker* const worker;
    State state = State::Running;  // Guarded by Worker's work.mutex.
  };

 private:
  Scheduler(const Scheduler&) = delete;
  Scheduler(Scheduler&&) = delete;
  Scheduler& operator=(const Scheduler&) = delete;
  Scheduler& operator=(Scheduler&&) = delete;

  // Maximum number of worker threads.
  static constexpr size_t MaxWorkerThreads = 256;

  // WaitingFibers holds all the fibers waiting on a timeout.
  struct WaitingFibers {
    inline WaitingFibers(Allocator*);

    // operator bool() returns true iff there are any waiting fibers.
    inline operator bool() const;

    // take() returns the next fiber that has exceeded its timeout, or nullptr
    // if there are no fibers that have yet exceeded their timeouts.
    inline Fiber* take(const TimePoint& timeout);

    // next() returns the timepoint of the next fiber to timeout.
    // next() can only be called if operator bool() returns true.
    inline TimePoint next() const;

    // add() adds another fiber and timeout to the list of waiting fibers.
    inline void add(const TimePoint& timeout, Fiber* fiber);

    // erase() removes the fiber from the waiting list.
    inline void erase(Fiber* fiber);

    // contains() returns true if fiber is waiting.
    inline bool contains(Fiber* fiber) const;

   private:
    struct Timeout {
      TimePoint timepoint;
      Fiber* fiber;
      inline bool operator<(const Timeout&) const;
    };
    containers::set<Timeout, std::less<Timeout>> timeouts;
    containers::unordered_map<Fiber*, TimePoint> fibers;
  };

  // TODO: Implement a queue that recycles elements to reduce number of
  // heap allocations.
  using TaskQueue = containers::deque<Task>;
  using FiberQueue = containers::deque<Fiber*>;
  using FiberSet = containers::unordered_set<Fiber*>;

  // Workers execute Tasks on a single thread.
  // Once a task is started, it may yield to other tasks on the same Worker.
  // Tasks are always resumed by the same Worker.
  class Worker {
   public:
    enum class Mode {
      // Worker will spawn a background thread to process tasks.
      MultiThreaded,

      // Worker will execute tasks whenever it yields.
      SingleThreaded,
    };

    Worker(Scheduler* scheduler, Mode mode, uint32_t id);

    // start() begins execution of the worker.
    void start() EXCLUDES(work.mutex);

    // stop() ceases execution of the worker, blocking until all pending
    // tasks have fully finished.
    void stop() EXCLUDES(work.mutex);

    // wait() suspends execution of the current task until the predicate pred
    // returns true or the optional timeout is reached.
    // See Fiber::wait() for more information.
    MARL_EXPORT
    bool wait(marl::lock& lock, const TimePoint* timeout, const Predicate& pred)
        EXCLUDES(work.mutex);

    // wait() suspends execution of the current task until the fiber is
    // notified, or the optional timeout is reached.
    // See Fiber::wait() for more information.
    MARL_EXPORT
    bool wait(const TimePoint* timeout) EXCLUDES(work.mutex);

    // suspend() suspends the currently executing Fiber until the fiber is
    // woken with a call to enqueue(Fiber*), or automatically sometime after
    // the optional timeout.
    void suspend(const TimePoint* timeout) REQUIRES(work.mutex);

    // enqueue(Fiber*) enqueues resuming of a suspended fiber.
    void enqueue(Fiber* fiber) EXCLUDES(work.mutex);

    // enqueue(Task&&) enqueues a new, unstarted task.
    void enqueue(Task&& task) EXCLUDES(work.mutex);

    // tryLock() attempts to lock the worker for task enqueuing.
    // If the lock was successful then true is returned, and the caller must
    // call enqueueAndUnlock().
    bool tryLock() EXCLUDES(work.mutex) TRY_ACQUIRE(true, work.mutex);

    // enqueueAndUnlock() enqueues the task and unlocks the worker.
    // Must only be called after a call to tryLock() which returned true.
    // _Releases_lock_(work.mutex)
    void enqueueAndUnlock(Task&& task) REQUIRES(work.mutex) RELEASE(work.mutex);

    // runUntilShutdown() processes all tasks and fibers until there are no
    // more and shutdown is true, at which point runUntilShutdown() returns.
    void runUntilShutdown() REQUIRES(work.mutex);

    // steal() attempts to steal a Task from the worker for another worker.
    // Returns true if a task was taken and assigned to out, otherwise false.
    bool steal(Task& out) EXCLUDES(work.mutex);

    // getCurrent() returns the Worker currently bound to the current
    // thread.
    static inline Worker* getCurrent();

    // getCurrentFiber() returns the Fiber currently being executed.
    inline Fiber* getCurrentFiber() const;

    // Unique identifier of the Worker.
    const uint32_t id;

   private:
    // run() is the task processing function for the worker.
    // run() processes tasks until stop() is called.
    void run() REQUIRES(work.mutex);

    // createWorkerFiber() creates a new fiber that when executed calls
    // run().
    Fiber* createWorkerFiber() REQUIRES(work.mutex);

    // switchToFiber() switches execution to the given fiber. The fiber
    // must belong to this worker.
    void switchToFiber(Fiber*) REQUIRES(work.mutex);

    // runUntilIdle() executes all pending tasks and then returns.
    void runUntilIdle() REQUIRES(work.mutex);

    // waitForWork() blocks until new work is available, potentially calling
    // spinForWorkAndLock().
    void waitForWork() REQUIRES(work.mutex);

    // spinForWorkAndLock() attempts to steal work from another Worker, and
    // keeps the thread awake for a short duration. This reduces overheads of
    // frequently putting the thread to sleep and re-waking. It locks the
    // mutex before returning so that a stolen task cannot be re-stolen by
    // other workers.
    void spinForWorkAndLock() ACQUIRE(work.mutex);

    // enqueueFiberTimeouts() enqueues all the fibers that have finished
    // waiting.
    void enqueueFiberTimeouts() REQUIRES(work.mutex);

    inline void changeFiberState(Fiber* fiber,
                                 Fiber::State from,
                                 Fiber::State to) const REQUIRES(work.mutex);

    inline void setFiberState(Fiber* fiber, Fiber::State to) const
        REQUIRES(work.mutex);

    // Work holds tasks and fibers that are enqueued on the Worker.
    struct Work {
      inline Work(Allocator*);

      std::atomic<uint64_t> num = {0};  // tasks.size() + fibers.size()
      GUARDED_BY(mutex) uint64_t numBlockedFibers = 0;
      GUARDED_BY(mutex) TaskQueue tasks;
      GUARDED_BY(mutex) FiberQueue fibers;
      GUARDED_BY(mutex) WaitingFibers waiting;
      GUARDED_BY(mutex) bool notifyAdded = true;
      std::condition_variable added;
      marl::mutex mutex;

      template <typename F>
      inline void wait(F&&) REQUIRES(mutex);
    };

    // FastRnd is a lightweight xorshift pseudo-random number generator, used
    // by the worker when picking another worker to steal tasks from.
    // https://en.wikipedia.org/wiki/Xorshift
    class FastRnd {
     public:
      inline uint64_t operator()() {
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        return x;
      }

     private:
      // Seeded from the clock so each instance produces a different sequence.
      uint64_t x = std::chrono::system_clock::now().time_since_epoch().count();
    };

    // The current worker bound to the current thread.
    MARL_DECLARE_THREAD_LOCAL(Worker*, current);

    Mode const mode;
    Scheduler* const scheduler;
    Allocator::unique_ptr<Fiber> mainFiber;
    Fiber* currentFiber = nullptr;
    Thread thread;
    Work work;
    FiberSet idleFibers;  // Fibers that have completed and can be reused.
    containers::vector<Allocator::unique_ptr<Fiber>, 16>
        workerFibers;  // All fibers created by this worker.
    FastRnd rng;
    bool shutdown = false;
  };

  // stealWork() attempts to steal a task from the worker with the given id.
  // Returns true if a task was stolen and assigned to out, otherwise false.
  bool stealWork(Worker* thief, uint64_t from, Task& out);

  // onBeginSpinning() is called when a Worker calls spinForWorkAndLock().
  // The scheduler will prioritize this worker for new tasks to try to prevent
  // it going to sleep.
  void onBeginSpinning(int workerId);

  // setBound() sets the scheduler bound to the current thread.
  static void setBound(Scheduler* scheduler);

  // The scheduler currently bound to the current thread.
  MARL_DECLARE_THREAD_LOCAL(Scheduler*, bound);

  // The immutable configuration used to build the scheduler.
  const Config cfg;

  std::array<std::atomic<int>, MaxWorkerThreads> spinningWorkers;
  std::atomic<unsigned int> nextSpinningWorkerIdx = {0x8000000};

  std::atomic<unsigned int> nextEnqueueIndex = {0};
  std::array<Worker*, MaxWorkerThreads> workerThreads;

  struct SingleThreadedWorkers {
    inline SingleThreadedWorkers(Allocator*);

    using WorkerByTid =
        containers::unordered_map<std::thread::id,
                                  Allocator::unique_ptr<Worker>>;
    marl::mutex mutex;
    GUARDED_BY(mutex) std::condition_variable unbind;
    GUARDED_BY(mutex) WorkerByTid byTid;
  };
  SingleThreadedWorkers singleThreadedWorkers;
};

////////////////////////////////////////////////////////////////////////////////
// Scheduler::Config
////////////////////////////////////////////////////////////////////////////////
Scheduler::Config& Scheduler::Config::setAllocator(Allocator* alloc) {
  allocator = alloc;
  return *this;
}

Scheduler::Config& Scheduler::Config::setFiberStackSize(size_t size) {
  fiberStackSize = size;
  return *this;
}

Scheduler::Config& Scheduler::Config::setWorkerThreadCount(int count) {
  workerThread.count = count;
  return *this;
}

Scheduler::Config& Scheduler::Config::setWorkerThreadInitializer(
    const ThreadInitializer& initializer) {
  workerThread.initializer = initializer;
  return *this;
}

Scheduler::Config& Scheduler::Config::setWorkerThreadAffinityPolicy(
    const std::shared_ptr<Thread::Affinity::Policy>& policy) {
  workerThread.affinityPolicy = policy;
  return *this;
}

////////////////////////////////////////////////////////////////////////////////
// Scheduler::Fiber
////////////////////////////////////////////////////////////////////////////////
template <typename Clock, typename Duration>
bool Scheduler::Fiber::wait(
    marl::lock& lock,
    const std::chrono::time_point<Clock, Duration>& timeout,
    const Predicate& pred) {
  using ToDuration = typename TimePoint::duration;
  using ToClock = typename TimePoint::clock;
  auto tp = std::chrono::time_point_cast<ToDuration, ToClock>(timeout);
  return worker->wait(lock, &tp, pred);
}

void Scheduler::Fiber::wait() {
  worker->wait(nullptr);
}

template <typename Clock, typename Duration>
bool Scheduler::Fiber::wait(
    const std::chrono::time_point<Clock, Duration>& timeout) {
  using ToDuration = typename TimePoint::duration;
  using ToClock = typename TimePoint::clock;
  auto tp = std::chrono::time_point_cast<ToDuration, ToClock>(timeout);
  return worker->wait(&tp);
}

Scheduler::Worker* Scheduler::Worker::getCurrent() {
  return Worker::current;
}

Scheduler::Fiber* Scheduler::Worker::getCurrentFiber() const {
  return currentFiber;
}

// schedule() schedules the task t to be asynchronously called using the
// currently bound scheduler.
inline void schedule(Task&& t) {
  MARL_ASSERT_HAS_BOUND_SCHEDULER("marl::schedule");
  auto scheduler = Scheduler::get();
  scheduler->enqueue(std::move(t));
}

// schedule() schedules the function f to be asynchronously called with the
// given arguments using the currently bound scheduler.
template <typename Function, typename... Args>
inline void schedule(Function&& f, Args&&... args) {
  MARL_ASSERT_HAS_BOUND_SCHEDULER("marl::schedule");
  auto scheduler = Scheduler::get();
  scheduler->enqueue(
      Task(std::bind(std::forward<Function>(f), std::forward<Args>(args)...)));
}

// schedule() schedules the function f to be asynchronously called using the
// currently bound scheduler.
template <typename Function>
inline void schedule(Function&& f) {
  MARL_ASSERT_HAS_BOUND_SCHEDULER("marl::schedule");
  auto scheduler = Scheduler::get();
  scheduler->enqueue(Task(std::forward<Function>(f)));
}
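
// Example usage of the schedule() overloads (a sketch; doWork and doWorkWith
// are hypothetical functions):
//
//   marl::schedule([] { doWork(); });      // lambda taking no arguments
//   marl::schedule(doWorkWith, 1, "two");  // function with bound arguments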

}  // namespace marl

#endif  // marl_scheduler_h