xref: /aosp_15_r20/external/swiftshader/third_party/marl/include/marl/scheduler.h (revision 03ce13f70fcc45d86ee91b7ee4cab1936a95046e)
// Copyright 2019 The Marl Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 
15 #ifndef marl_scheduler_h
16 #define marl_scheduler_h
17 
18 #include "containers.h"
19 #include "debug.h"
20 #include "deprecated.h"
21 #include "export.h"
22 #include "memory.h"
23 #include "mutex.h"
24 #include "sanitizers.h"
25 #include "task.h"
26 #include "thread.h"
27 #include "thread_local.h"
28 
29 #include <array>
30 #include <atomic>
31 #include <chrono>
32 #include <condition_variable>
33 #include <functional>
34 #include <thread>
35 
36 namespace marl {
37 
38 class OSFiber;
39 
40 // Scheduler asynchronously processes Tasks.
41 // A scheduler can be bound to one or more threads using the bind() method.
42 // Once bound to a thread, that thread can call marl::schedule() to enqueue
43 // work tasks to be executed asynchronously.
44 // Scheduler are initially constructed in single-threaded mode.
45 // Call setWorkerThreadCount() to spawn dedicated worker threads.
46 class Scheduler {
47   class Worker;
48 
49  public:
50   using TimePoint = std::chrono::system_clock::time_point;
51   using Predicate = std::function<bool()>;
52   using ThreadInitializer = std::function<void(int workerId)>;
53 
54   // Config holds scheduler configuration settings that can be passed to the
55   // Scheduler constructor.
56   struct Config {
57     static constexpr size_t DefaultFiberStackSize = 1024 * 1024;
58 
59     // Per-worker-thread settings.
60     struct WorkerThread {
61       // Total number of dedicated worker threads to spawn for the scheduler.
62       int count = 0;
63 
64       // Initializer function to call after thread creation and before any work
65       // is run by the thread.
66       ThreadInitializer initializer;
67 
68       // Thread affinity policy to use for worker threads.
69       std::shared_ptr<Thread::Affinity::Policy> affinityPolicy;
70     };
71 
72     WorkerThread workerThread;
73 
74     // Memory allocator to use for the scheduler and internal allocations.
75     Allocator* allocator = Allocator::Default;
76 
77     // Size of each fiber stack. This may be rounded up to the nearest
78     // allocation granularity for the given platform.
79     size_t fiberStackSize = DefaultFiberStackSize;
80 
81     // allCores() returns a Config with a worker thread for each of the logical
82     // cpus available to the process.
83     MARL_EXPORT
84     static Config allCores();
85 
86     // Fluent setters that return this Config so set calls can be chained.
87     MARL_NO_EXPORT inline Config& setAllocator(Allocator*);
88     MARL_NO_EXPORT inline Config& setFiberStackSize(size_t);
89     MARL_NO_EXPORT inline Config& setWorkerThreadCount(int);
90     MARL_NO_EXPORT inline Config& setWorkerThreadInitializer(
91         const ThreadInitializer&);
92     MARL_NO_EXPORT inline Config& setWorkerThreadAffinityPolicy(
93         const std::shared_ptr<Thread::Affinity::Policy>&);
94   };
95 
96   // Constructor.
97   MARL_EXPORT
98   Scheduler(const Config&);
99 
100   // Destructor.
101   // Blocks until the scheduler is unbound from all threads before returning.
102   MARL_EXPORT
103   ~Scheduler();
104 
105   // get() returns the scheduler bound to the current thread.
106   MARL_EXPORT
107   static Scheduler* get();
108 
109   // bind() binds this scheduler to the current thread.
110   // There must be no existing scheduler bound to the thread prior to calling.
111   MARL_EXPORT
112   void bind();
113 
114   // unbind() unbinds the scheduler currently bound to the current thread.
115   // There must be an existing scheduler bound to the thread prior to calling.
116   // unbind() flushes any enqueued tasks on the single-threaded worker before
117   // returning.
118   MARL_EXPORT
119   static void unbind();
120 
121   // enqueue() queues the task for asynchronous execution.
122   MARL_EXPORT
123   void enqueue(Task&& task);
124 
125   // config() returns the Config that was used to build the scheduler.
126   MARL_EXPORT
127   const Config& config() const;
128 
129   // Fibers expose methods to perform cooperative multitasking and are
130   // automatically created by the Scheduler.
131   //
132   // The currently executing Fiber can be obtained by calling Fiber::current().
133   //
134   // When execution becomes blocked, yield() can be called to suspend execution
135   // of the fiber and start executing other pending work. Once the block has
136   // been lifted, schedule() can be called to reschedule the Fiber on the same
137   // thread that previously executed it.
138   class Fiber {
139    public:
140     // current() returns the currently executing fiber, or nullptr if called
141     // without a bound scheduler.
142     MARL_EXPORT
143     static Fiber* current();
144 
145     // wait() suspends execution of this Fiber until the Fiber is woken up with
146     // a call to notify() and the predicate pred returns true.
147     // If the predicate pred does not return true when notify() is called, then
148     // the Fiber is automatically re-suspended, and will need to be woken with
149     // another call to notify().
150     // While the Fiber is suspended, the scheduler thread may continue executing
151     // other tasks.
152     // lock must be locked before calling, and is unlocked by wait() just before
153     // the Fiber is suspended, and re-locked before the fiber is resumed. lock
154     // will be locked before wait() returns.
155     // pred will be always be called with the lock held.
156     // wait() must only be called on the currently executing fiber.
157     MARL_EXPORT
158     void wait(marl::lock& lock, const Predicate& pred);
159 
160     // wait() suspends execution of this Fiber until the Fiber is woken up with
161     // a call to notify() and the predicate pred returns true, or sometime after
162     // the timeout is reached.
163     // If the predicate pred does not return true when notify() is called, then
164     // the Fiber is automatically re-suspended, and will need to be woken with
165     // another call to notify() or will be woken sometime after the timeout is
166     // reached.
167     // While the Fiber is suspended, the scheduler thread may continue executing
168     // other tasks.
169     // lock must be locked before calling, and is unlocked by wait() just before
170     // the Fiber is suspended, and re-locked before the fiber is resumed. lock
171     // will be locked before wait() returns.
172     // pred will be always be called with the lock held.
173     // wait() must only be called on the currently executing fiber.
174     template <typename Clock, typename Duration>
175     MARL_NO_EXPORT inline bool wait(
176         marl::lock& lock,
177         const std::chrono::time_point<Clock, Duration>& timeout,
178         const Predicate& pred);
179 
180     // wait() suspends execution of this Fiber until the Fiber is woken up with
181     // a call to notify().
182     // While the Fiber is suspended, the scheduler thread may continue executing
183     // other tasks.
184     // wait() must only be called on the currently executing fiber.
185     //
186     // Warning: Unlike wait() overloads that take a lock and predicate, this
187     // form of wait() offers no safety for notify() signals that occur before
188     // the fiber is suspended, when signalling between different threads. In
189     // this scenario you may deadlock. For this reason, it is only ever
190     // recommended to use this overload if you can guarantee that the calls to
191     // wait() and notify() are made by the same thread.
192     //
193     // Use with extreme caution.
194     MARL_NO_EXPORT inline void wait();
195 
196     // wait() suspends execution of this Fiber until the Fiber is woken up with
197     // a call to notify(), or sometime after the timeout is reached.
198     // While the Fiber is suspended, the scheduler thread may continue executing
199     // other tasks.
200     // wait() must only be called on the currently executing fiber.
201     //
202     // Warning: Unlike wait() overloads that take a lock and predicate, this
203     // form of wait() offers no safety for notify() signals that occur before
204     // the fiber is suspended, when signalling between different threads. For
205     // this reason, it is only ever recommended to use this overload if you can
206     // guarantee that the calls to wait() and notify() are made by the same
207     // thread.
208     //
209     // Use with extreme caution.
210     template <typename Clock, typename Duration>
211     MARL_NO_EXPORT inline bool wait(
212         const std::chrono::time_point<Clock, Duration>& timeout);
213 
214     // notify() reschedules the suspended Fiber for execution.
215     // notify() is usually only called when the predicate for one or more wait()
216     // calls will likely return true.
217     MARL_EXPORT
218     void notify();
219 
220     // id is the thread-unique identifier of the Fiber.
221     uint32_t const id;
222 
223    private:
224     friend class Allocator;
225     friend class Scheduler;
226 
227     enum class State {
228       // Idle: the Fiber is currently unused, and sits in Worker::idleFibers,
229       // ready to be recycled.
230       Idle,
231 
232       // Yielded: the Fiber is currently blocked on a wait() call with no
233       // timeout.
234       Yielded,
235 
236       // Waiting: the Fiber is currently blocked on a wait() call with a
237       // timeout. The fiber is stilling in the Worker::Work::waiting queue.
238       Waiting,
239 
240       // Queued: the Fiber is currently queued for execution in the
241       // Worker::Work::fibers queue.
242       Queued,
243 
244       // Running: the Fiber is currently executing.
245       Running,
246     };
247 
248     Fiber(Allocator::unique_ptr<OSFiber>&&, uint32_t id);
249 
250     // switchTo() switches execution to the given fiber.
251     // switchTo() must only be called on the currently executing fiber.
252     void switchTo(Fiber*);
253 
254     // create() constructs and returns a new fiber with the given identifier,
255     // stack size and func that will be executed when switched to.
256     static Allocator::unique_ptr<Fiber> create(
257         Allocator* allocator,
258         uint32_t id,
259         size_t stackSize,
260         const std::function<void()>& func);
261 
262     // createFromCurrentThread() constructs and returns a new fiber with the
263     // given identifier for the current thread.
264     static Allocator::unique_ptr<Fiber> createFromCurrentThread(
265         Allocator* allocator,
266         uint32_t id);
267 
268     // toString() returns a string representation of the given State.
269     // Used for debugging.
270     static const char* toString(State state);
271 
272     Allocator::unique_ptr<OSFiber> const impl;
273     Worker* const worker;
274     State state = State::Running;  // Guarded by Worker's work.mutex.
275   };
276 
277  private:
278   Scheduler(const Scheduler&) = delete;
279   Scheduler(Scheduler&&) = delete;
280   Scheduler& operator=(const Scheduler&) = delete;
281   Scheduler& operator=(Scheduler&&) = delete;
282 
283   // Maximum number of worker threads.
284   static constexpr size_t MaxWorkerThreads = 256;
285 
286   // WaitingFibers holds all the fibers waiting on a timeout.
287   struct WaitingFibers {
288     inline WaitingFibers(Allocator*);
289 
290     // operator bool() returns true iff there are any wait fibers.
291     inline operator bool() const;
292 
293     // take() returns the next fiber that has exceeded its timeout, or nullptr
294     // if there are no fibers that have yet exceeded their timeouts.
295     inline Fiber* take(const TimePoint& timeout);
296 
297     // next() returns the timepoint of the next fiber to timeout.
298     // next() can only be called if operator bool() returns true.
299     inline TimePoint next() const;
300 
301     // add() adds another fiber and timeout to the list of waiting fibers.
302     inline void add(const TimePoint& timeout, Fiber* fiber);
303 
304     // erase() removes the fiber from the waiting list.
305     inline void erase(Fiber* fiber);
306 
307     // contains() returns true if fiber is waiting.
308     inline bool contains(Fiber* fiber) const;
309 
310    private:
311     struct Timeout {
312       TimePoint timepoint;
313       Fiber* fiber;
314       inline bool operator<(const Timeout&) const;
315     };
316     containers::set<Timeout, std::less<Timeout>> timeouts;
317     containers::unordered_map<Fiber*, TimePoint> fibers;
318   };
319 
320   // TODO: Implement a queue that recycles elements to reduce number of
321   // heap allocations.
322   using TaskQueue = containers::deque<Task>;
323   using FiberQueue = containers::deque<Fiber*>;
324   using FiberSet = containers::unordered_set<Fiber*>;
325 
326   // Workers execute Tasks on a single thread.
327   // Once a task is started, it may yield to other tasks on the same Worker.
328   // Tasks are always resumed by the same Worker.
329   class Worker {
330    public:
331     enum class Mode {
332       // Worker will spawn a background thread to process tasks.
333       MultiThreaded,
334 
335       // Worker will execute tasks whenever it yields.
336       SingleThreaded,
337     };
338 
339     Worker(Scheduler* scheduler, Mode mode, uint32_t id);
340 
341     // start() begins execution of the worker.
342     void start() EXCLUDES(work.mutex);
343 
344     // stop() ceases execution of the worker, blocking until all pending
345     // tasks have fully finished.
346     void stop() EXCLUDES(work.mutex);
347 
348     // wait() suspends execution of the current task until the predicate pred
349     // returns true or the optional timeout is reached.
350     // See Fiber::wait() for more information.
351     MARL_EXPORT
352     bool wait(marl::lock& lock, const TimePoint* timeout, const Predicate& pred)
353         EXCLUDES(work.mutex);
354 
355     // wait() suspends execution of the current task until the fiber is
356     // notified, or the optional timeout is reached.
357     // See Fiber::wait() for more information.
358     MARL_EXPORT
359     bool wait(const TimePoint* timeout) EXCLUDES(work.mutex);
360 
361     // suspend() suspends the currently executing Fiber until the fiber is
362     // woken with a call to enqueue(Fiber*), or automatically sometime after the
363     // optional timeout.
364     void suspend(const TimePoint* timeout) REQUIRES(work.mutex);
365 
366     // enqueue(Fiber*) enqueues resuming of a suspended fiber.
367     void enqueue(Fiber* fiber) EXCLUDES(work.mutex);
368 
369     // enqueue(Task&&) enqueues a new, unstarted task.
370     void enqueue(Task&& task) EXCLUDES(work.mutex);
371 
372     // tryLock() attempts to lock the worker for task enqueuing.
373     // If the lock was successful then true is returned, and the caller must
374     // call enqueueAndUnlock().
375     bool tryLock() EXCLUDES(work.mutex) TRY_ACQUIRE(true, work.mutex);
376 
377     // enqueueAndUnlock() enqueues the task and unlocks the worker.
378     // Must only be called after a call to tryLock() which returned true.
379     // _Releases_lock_(work.mutex)
380     void enqueueAndUnlock(Task&& task) REQUIRES(work.mutex) RELEASE(work.mutex);
381 
382     // runUntilShutdown() processes all tasks and fibers until there are no more
383     // and shutdown is true, upon runUntilShutdown() returns.
384     void runUntilShutdown() REQUIRES(work.mutex);
385 
386     // steal() attempts to steal a Task from the worker for another worker.
387     // Returns true if a task was taken and assigned to out, otherwise false.
388     bool steal(Task& out) EXCLUDES(work.mutex);
389 
390     // getCurrent() returns the Worker currently bound to the current
391     // thread.
392     static inline Worker* getCurrent();
393 
394     // getCurrentFiber() returns the Fiber currently being executed.
395     inline Fiber* getCurrentFiber() const;
396 
397     // Unique identifier of the Worker.
398     const uint32_t id;
399 
400    private:
401     // run() is the task processing function for the worker.
402     // run() processes tasks until stop() is called.
403     void run() REQUIRES(work.mutex);
404 
405     // createWorkerFiber() creates a new fiber that when executed calls
406     // run().
407     Fiber* createWorkerFiber() REQUIRES(work.mutex);
408 
409     // switchToFiber() switches execution to the given fiber. The fiber
410     // must belong to this worker.
411     void switchToFiber(Fiber*) REQUIRES(work.mutex);
412 
413     // runUntilIdle() executes all pending tasks and then returns.
414     void runUntilIdle() REQUIRES(work.mutex);
415 
416     // waitForWork() blocks until new work is available, potentially calling
417     // spinForWork().
418     void waitForWork() REQUIRES(work.mutex);
419 
420     // spinForWorkAndLock() attempts to steal work from another Worker, and keeps
421     // the thread awake for a short duration. This reduces overheads of
422     // frequently putting the thread to sleep and re-waking. It locks the mutex
423     // before returning so that a stolen task cannot be re-stolen by other workers.
424     void spinForWorkAndLock() ACQUIRE(work.mutex);
425 
426     // enqueueFiberTimeouts() enqueues all the fibers that have finished
427     // waiting.
428     void enqueueFiberTimeouts() REQUIRES(work.mutex);
429 
430     inline void changeFiberState(Fiber* fiber,
431                                  Fiber::State from,
432                                  Fiber::State to) const REQUIRES(work.mutex);
433 
434     inline void setFiberState(Fiber* fiber, Fiber::State to) const
435         REQUIRES(work.mutex);
436 
437     // Work holds tasks and fibers that are enqueued on the Worker.
438     struct Work {
439       inline Work(Allocator*);
440 
441       std::atomic<uint64_t> num = {0};  // tasks.size() + fibers.size()
442       GUARDED_BY(mutex) uint64_t numBlockedFibers = 0;
443       GUARDED_BY(mutex) TaskQueue tasks;
444       GUARDED_BY(mutex) FiberQueue fibers;
445       GUARDED_BY(mutex) WaitingFibers waiting;
446       GUARDED_BY(mutex) bool notifyAdded = true;
447       std::condition_variable added;
448       marl::mutex mutex;
449 
450       template <typename F>
451       inline void wait(F&&) REQUIRES(mutex);
452     };
453 
454     // https://en.wikipedia.org/wiki/Xorshift
455     class FastRnd {
456      public:
operator()457       inline uint64_t operator()() {
458         x ^= x << 13;
459         x ^= x >> 7;
460         x ^= x << 17;
461         return x;
462       }
463 
464      private:
465       uint64_t x = std::chrono::system_clock::now().time_since_epoch().count();
466     };
467 
468     // The current worker bound to the current thread.
469     MARL_DECLARE_THREAD_LOCAL(Worker*, current);
470 
471     Mode const mode;
472     Scheduler* const scheduler;
473     Allocator::unique_ptr<Fiber> mainFiber;
474     Fiber* currentFiber = nullptr;
475     Thread thread;
476     Work work;
477     FiberSet idleFibers;  // Fibers that have completed which can be reused.
478     containers::vector<Allocator::unique_ptr<Fiber>, 16>
479         workerFibers;  // All fibers created by this worker.
480     FastRnd rng;
481     bool shutdown = false;
482   };
483 
484   // stealWork() attempts to steal a task from the worker with the given id.
485   // Returns true if a task was stolen and assigned to out, otherwise false.
486   bool stealWork(Worker* thief, uint64_t from, Task& out);
487 
488   // onBeginSpinning() is called when a Worker calls spinForWork().
489   // The scheduler will prioritize this worker for new tasks to try to prevent
490   // it going to sleep.
491   void onBeginSpinning(int workerId);
492 
493   // setBound() sets the scheduler bound to the current thread.
494   static void setBound(Scheduler* scheduler);
495 
496   // The scheduler currently bound to the current thread.
497   MARL_DECLARE_THREAD_LOCAL(Scheduler*, bound);
498 
499   // The immutable configuration used to build the scheduler.
500   const Config cfg;
501 
502   std::array<std::atomic<int>, MaxWorkerThreads> spinningWorkers;
503   std::atomic<unsigned int> nextSpinningWorkerIdx = {0x8000000};
504 
505   std::atomic<unsigned int> nextEnqueueIndex = {0};
506   std::array<Worker*, MaxWorkerThreads> workerThreads;
507 
508   struct SingleThreadedWorkers {
509     inline SingleThreadedWorkers(Allocator*);
510 
511     using WorkerByTid =
512         containers::unordered_map<std::thread::id,
513                                   Allocator::unique_ptr<Worker>>;
514     marl::mutex mutex;
515     GUARDED_BY(mutex) std::condition_variable unbind;
516     GUARDED_BY(mutex) WorkerByTid byTid;
517   };
518   SingleThreadedWorkers singleThreadedWorkers;
519 };
520 
////////////////////////////////////////////////////////////////////////////////
// Scheduler::Config
////////////////////////////////////////////////////////////////////////////////
setAllocator(Allocator * alloc)524 Scheduler::Config& Scheduler::Config::setAllocator(Allocator* alloc) {
525   allocator = alloc;
526   return *this;
527 }
528 
setFiberStackSize(size_t size)529 Scheduler::Config& Scheduler::Config::setFiberStackSize(size_t size) {
530   fiberStackSize = size;
531   return *this;
532 }
533 
setWorkerThreadCount(int count)534 Scheduler::Config& Scheduler::Config::setWorkerThreadCount(int count) {
535   workerThread.count = count;
536   return *this;
537 }
538 
setWorkerThreadInitializer(const ThreadInitializer & initializer)539 Scheduler::Config& Scheduler::Config::setWorkerThreadInitializer(
540     const ThreadInitializer& initializer) {
541   workerThread.initializer = initializer;
542   return *this;
543 }
544 
setWorkerThreadAffinityPolicy(const std::shared_ptr<Thread::Affinity::Policy> & policy)545 Scheduler::Config& Scheduler::Config::setWorkerThreadAffinityPolicy(
546     const std::shared_ptr<Thread::Affinity::Policy>& policy) {
547   workerThread.affinityPolicy = policy;
548   return *this;
549 }
550 
////////////////////////////////////////////////////////////////////////////////
// Scheduler::Fiber
////////////////////////////////////////////////////////////////////////////////
554 template <typename Clock, typename Duration>
wait(marl::lock & lock,const std::chrono::time_point<Clock,Duration> & timeout,const Predicate & pred)555 bool Scheduler::Fiber::wait(
556     marl::lock& lock,
557     const std::chrono::time_point<Clock, Duration>& timeout,
558     const Predicate& pred) {
559   using ToDuration = typename TimePoint::duration;
560   using ToClock = typename TimePoint::clock;
561   auto tp = std::chrono::time_point_cast<ToDuration, ToClock>(timeout);
562   return worker->wait(lock, &tp, pred);
563 }
564 
wait()565 void Scheduler::Fiber::wait() {
566   worker->wait(nullptr);
567 }
568 
569 template <typename Clock, typename Duration>
wait(const std::chrono::time_point<Clock,Duration> & timeout)570 bool Scheduler::Fiber::wait(
571     const std::chrono::time_point<Clock, Duration>& timeout) {
572   using ToDuration = typename TimePoint::duration;
573   using ToClock = typename TimePoint::clock;
574   auto tp = std::chrono::time_point_cast<ToDuration, ToClock>(timeout);
575   return worker->wait(&tp);
576 }
577 
getCurrent()578 Scheduler::Worker* Scheduler::Worker::getCurrent() {
579   return Worker::current;
580 }
581 
getCurrentFiber()582 Scheduler::Fiber* Scheduler::Worker::getCurrentFiber() const {
583   return currentFiber;
584 }
585 
586 // schedule() schedules the task T to be asynchronously called using the
587 // currently bound scheduler.
schedule(Task && t)588 inline void schedule(Task&& t) {
589   MARL_ASSERT_HAS_BOUND_SCHEDULER("marl::schedule");
590   auto scheduler = Scheduler::get();
591   scheduler->enqueue(std::move(t));
592 }
593 
594 // schedule() schedules the function f to be asynchronously called with the
595 // given arguments using the currently bound scheduler.
596 template <typename Function, typename... Args>
schedule(Function && f,Args &&...args)597 inline void schedule(Function&& f, Args&&... args) {
598   MARL_ASSERT_HAS_BOUND_SCHEDULER("marl::schedule");
599   auto scheduler = Scheduler::get();
600   scheduler->enqueue(
601       Task(std::bind(std::forward<Function>(f), std::forward<Args>(args)...)));
602 }
603 
604 // schedule() schedules the function f to be asynchronously called using the
605 // currently bound scheduler.
606 template <typename Function>
schedule(Function && f)607 inline void schedule(Function&& f) {
608   MARL_ASSERT_HAS_BOUND_SCHEDULER("marl::schedule");
609   auto scheduler = Scheduler::get();
610   scheduler->enqueue(Task(std::forward<Function>(f)));
611 }
612 
613 }  // namespace marl
614 
615 #endif  // marl_scheduler_h
616