// Copyright 2019 The Marl Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "osfiber.h"  // Must come first. See osfiber_ucontext.h.

#include "marl/scheduler.h"

#include "marl/debug.h"
#include "marl/thread.h"
#include "marl/trace.h"

#if defined(_WIN32)
#include <intrin.h>  // __nop()
#endif

// Enable to trace scheduler events.
#define ENABLE_TRACE_EVENTS 0

// Enable to print verbose debug logging.
#define ENABLE_DEBUG_LOGGING 0

#if ENABLE_TRACE_EVENTS
#define TRACE(...) MARL_SCOPED_EVENT(__VA_ARGS__)
#else
#define TRACE(...)
#endif

#if ENABLE_DEBUG_LOGGING
#define DBG_LOG(msg, ...) \
  printf("%.3x " msg "\n", (int)threadID() & 0xfff, __VA_ARGS__)
#else
#define DBG_LOG(msg, ...)
#endif

#define ASSERT_FIBER_STATE(FIBER, STATE)                                   \
  MARL_ASSERT(FIBER->state == STATE,                                       \
              "fiber %d was in state %s, but expected %s", (int)FIBER->id, \
              Fiber::toString(FIBER->state), Fiber::toString(STATE))

namespace {

#if ENABLE_DEBUG_LOGGING
// threadID() returns a uint64_t representing the currently executing thread.
// threadID() is only intended to be used for debugging purposes.
inline uint64_t threadID() {
  auto id = std::this_thread::get_id();
  return std::hash<std::thread::id>()(id);
}
#endif

inline void nop() {
#if defined(_WIN32)
  __nop();
#else
  __asm__ __volatile__("nop");
#endif
}

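// setConfigDefaults() returns a copy of cfgIn with unset fields replaced with
// default values. Currently this only assigns a default worker thread
// affinity policy when worker threads are enabled but no policy was given.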
inline marl::Scheduler::Config setConfigDefaults(
    const marl::Scheduler::Config& cfgIn) {
  marl::Scheduler::Config cfg{cfgIn};
  if (cfg.workerThread.count > 0 && !cfg.workerThread.affinityPolicy) {
    cfg.workerThread.affinityPolicy = marl::Thread::Affinity::Policy::anyOf(
        marl::Thread::Affinity::all(cfg.allocator), cfg.allocator);
  }
  return cfg;
}

}  // anonymous namespace

namespace marl {

////////////////////////////////////////////////////////////////////////////////
// Scheduler
////////////////////////////////////////////////////////////////////////////////
MARL_INSTANTIATE_THREAD_LOCAL(Scheduler*, Scheduler::bound, nullptr);

Scheduler* Scheduler::get() {
  return bound;
}

void Scheduler::setBound(Scheduler* scheduler) {
  bound = scheduler;
}

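// bind() associates this scheduler with the calling thread, creating a
// single-threaded worker for it so that tasks can be scheduled and waited on
// from this thread. A minimal usage sketch (illustrative only; defer() comes
// from marl/defer.h and marl::schedule() from marl/scheduler.h):
//
//   marl::Scheduler scheduler(marl::Scheduler::Config::allCores());
//   scheduler.bind();
//   defer(scheduler.unbind());  // Unbind before the scheduler is destructed.
//   marl::schedule([] { /* task body */ });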
void Scheduler::bind() {
  MARL_ASSERT(get() == nullptr, "Scheduler already bound");
  setBound(this);
  {
    marl::lock lock(singleThreadedWorkers.mutex);
    auto worker = cfg.allocator->make_unique<Worker>(
        this, Worker::Mode::SingleThreaded, -1);
    worker->start();
    auto tid = std::this_thread::get_id();
    singleThreadedWorkers.byTid.emplace(tid, std::move(worker));
  }
}

void Scheduler::unbind() {
  MARL_ASSERT(get() != nullptr, "No scheduler bound");
  auto worker = Worker::getCurrent();
  worker->stop();
  {
    marl::lock lock(get()->singleThreadedWorkers.mutex);
    auto tid = std::this_thread::get_id();
    auto& workers = get()->singleThreadedWorkers.byTid;
    auto it = workers.find(tid);
    MARL_ASSERT(it != workers.end(), "singleThreadedWorker not found");
    MARL_ASSERT(it->second.get() == worker, "worker is not bound?");
    workers.erase(it);
    if (workers.empty()) {
      get()->singleThreadedWorkers.unbind.notify_one();
    }
  }
  setBound(nullptr);
}

Scheduler::Scheduler(const Config& config)
    : cfg(setConfigDefaults(config)),
      workerThreads{},
      singleThreadedWorkers(config.allocator) {
  for (int i = 0; i < cfg.workerThread.count; i++) {
    spinningWorkers[i] = -1;
    workerThreads[i] =
        cfg.allocator->create<Worker>(this, Worker::Mode::MultiThreaded, i);
  }
  for (int i = 0; i < cfg.workerThread.count; i++) {
    workerThreads[i]->start();
  }
}

Scheduler::~Scheduler() {
  {
    // Wait until all the single threaded workers have been unbound.
    marl::lock lock(singleThreadedWorkers.mutex);
    lock.wait(singleThreadedWorkers.unbind,
              [this]() REQUIRES(singleThreadedWorkers.mutex) {
                return singleThreadedWorkers.byTid.empty();
              });
  }

  // Release all worker threads.
  // This will wait for all in-flight tasks to complete before returning.
  for (int i = cfg.workerThread.count - 1; i >= 0; i--) {
    workerThreads[i]->stop();
  }
  for (int i = cfg.workerThread.count - 1; i >= 0; i--) {
    cfg.allocator->destroy(workerThreads[i]);
  }
}

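// enqueue() submits the task to a worker queue. Tasks flagged SameThread are
// pinned to the calling thread's worker; otherwise the scheduler prefers a
// recently-spinning worker (which is idle and can pick the task up
// immediately), falling back to round-robin across the worker threads.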
void Scheduler::enqueue(Task&& task) {
  if (task.is(Task::Flags::SameThread)) {
    Worker::getCurrent()->enqueue(std::move(task));
    return;
  }
  if (cfg.workerThread.count > 0) {
    while (true) {
      // Prioritize workers that have recently started spinning.
      auto i = --nextSpinningWorkerIdx % cfg.workerThread.count;
      auto idx = spinningWorkers[i].exchange(-1);
      if (idx < 0) {
        // If a spinning worker couldn't be found, round-robin the workers.
        idx = nextEnqueueIndex++ % cfg.workerThread.count;
      }

      auto worker = workerThreads[idx];
      if (worker->tryLock()) {
        worker->enqueueAndUnlock(std::move(task));
        return;
      }
    }
  } else {
    if (auto worker = Worker::getCurrent()) {
      worker->enqueue(std::move(task));
    } else {
      MARL_FATAL(
          "singleThreadedWorker not found. Did you forget to call "
          "marl::Scheduler::bind()?");
    }
  }
}

const Scheduler::Config& Scheduler::config() const {
  return cfg;
}

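// stealWork() attempts to steal a single task from the worker indexed by
// `from` (modulo the worker count), writing it to `out` and returning true on
// success. A worker never steals from itself.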
bool Scheduler::stealWork(Worker* thief, uint64_t from, Task& out) {
  if (cfg.workerThread.count > 0) {
    auto thread = workerThreads[from % cfg.workerThread.count];
    if (thread != thief) {
      if (thread->steal(out)) {
        return true;
      }
    }
  }
  return false;
}

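// onBeginSpinning() records that the given worker has started spinning for
// work, publishing its id into the spinningWorkers ring buffer so that
// enqueue() can preferentially hand it the next task.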
void Scheduler::onBeginSpinning(int workerId) {
  auto idx = nextSpinningWorkerIdx++ % cfg.workerThread.count;
  spinningWorkers[idx] = workerId;
}

////////////////////////////////////////////////////////////////////////////////
// Scheduler::Config
////////////////////////////////////////////////////////////////////////////////
Scheduler::Config Scheduler::Config::allCores() {
  return Config().setWorkerThreadCount(Thread::numLogicalCPUs());
}

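// Configs can also be built explicitly. An illustrative sketch using the same
// setWorkerThreadCount() setter as above:
//
//   marl::Scheduler::Config cfg;
//   cfg.setWorkerThreadCount(4);  // Use four dedicated worker threads.
//   marl::Scheduler scheduler(cfg);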
////////////////////////////////////////////////////////////////////////////////
// Scheduler::Fiber
////////////////////////////////////////////////////////////////////////////////
Scheduler::Fiber::Fiber(Allocator::unique_ptr<OSFiber>&& impl, uint32_t id)
    : id(id), impl(std::move(impl)), worker(Worker::getCurrent()) {
  MARL_ASSERT(worker != nullptr, "No Scheduler::Worker bound");
}

Scheduler::Fiber* Scheduler::Fiber::current() {
  auto worker = Worker::getCurrent();
  return worker != nullptr ? worker->getCurrentFiber() : nullptr;
}

void Scheduler::Fiber::notify() {
  worker->enqueue(this);
}

void Scheduler::Fiber::wait(marl::lock& lock, const Predicate& pred) {
  MARL_ASSERT(worker == Worker::getCurrent(),
              "Scheduler::Fiber::wait() must only be called on the currently "
              "executing fiber");
  worker->wait(lock, nullptr, pred);
}

void Scheduler::Fiber::switchTo(Fiber* to) {
  MARL_ASSERT(worker == Worker::getCurrent(),
              "Scheduler::Fiber::switchTo() must only be called on the "
              "currently executing fiber");
  if (to != this) {
    impl->switchTo(to->impl.get());
  }
}

Allocator::unique_ptr<Scheduler::Fiber> Scheduler::Fiber::create(
    Allocator* allocator,
    uint32_t id,
    size_t stackSize,
    const std::function<void()>& func) {
  return allocator->make_unique<Fiber>(
      OSFiber::createFiber(allocator, stackSize, func), id);
}

Allocator::unique_ptr<Scheduler::Fiber>
Scheduler::Fiber::createFromCurrentThread(Allocator* allocator, uint32_t id) {
  return allocator->make_unique<Fiber>(
      OSFiber::createFiberFromCurrentThread(allocator), id);
}

const char* Scheduler::Fiber::toString(State state) {
  switch (state) {
    case State::Idle:
      return "Idle";
    case State::Yielded:
      return "Yielded";
    case State::Queued:
      return "Queued";
    case State::Running:
      return "Running";
    case State::Waiting:
      return "Waiting";
  }
  MARL_ASSERT(false, "bad fiber state");
  return "<unknown>";
}

////////////////////////////////////////////////////////////////////////////////
// Scheduler::WaitingFibers
////////////////////////////////////////////////////////////////////////////////
Scheduler::WaitingFibers::WaitingFibers(Allocator* allocator)
    : timeouts(allocator), fibers(allocator) {}

Scheduler::WaitingFibers::operator bool() const {
  return !fibers.empty();
}

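// take() returns the waiting fiber with the earliest timeout that has expired
// at or before `timeout`, removing it from both internal maps. Returns
// nullptr if no fibers are waiting, or none have timed out yet.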
Scheduler::Fiber* Scheduler::WaitingFibers::take(const TimePoint& timeout) {
  if (!*this) {
    return nullptr;
  }
  auto it = timeouts.begin();
  if (timeout < it->timepoint) {
    return nullptr;
  }
  auto fiber = it->fiber;
  timeouts.erase(it);
  auto deleted = fibers.erase(fiber) != 0;
  (void)deleted;
  MARL_ASSERT(deleted, "WaitingFibers::take() maps out of sync");
  return fiber;
}

Scheduler::TimePoint Scheduler::WaitingFibers::next() const {
  MARL_ASSERT(*this,
              "WaitingFibers::next() called when there are no waiting fibers");
  return timeouts.begin()->timepoint;
}

void Scheduler::WaitingFibers::add(const TimePoint& timeout, Fiber* fiber) {
  timeouts.emplace(Timeout{timeout, fiber});
  bool added = fibers.emplace(fiber, timeout).second;
  (void)added;
  MARL_ASSERT(added, "WaitingFibers::add() fiber already waiting");
}

void Scheduler::WaitingFibers::erase(Fiber* fiber) {
  auto it = fibers.find(fiber);
  if (it != fibers.end()) {
    auto timeout = it->second;
    auto erased = timeouts.erase(Timeout{timeout, fiber}) != 0;
    (void)erased;
    MARL_ASSERT(erased, "WaitingFibers::erase() maps out of sync");
    fibers.erase(it);
  }
}

bool Scheduler::WaitingFibers::contains(Fiber* fiber) const {
  return fibers.count(fiber) != 0;
}

bool Scheduler::WaitingFibers::Timeout::operator<(const Timeout& o) const {
  if (timepoint != o.timepoint) {
    return timepoint < o.timepoint;
  }
  return fiber < o.fiber;
}

////////////////////////////////////////////////////////////////////////////////
// Scheduler::Worker
////////////////////////////////////////////////////////////////////////////////
MARL_INSTANTIATE_THREAD_LOCAL(Scheduler::Worker*,
                              Scheduler::Worker::current,
                              nullptr);

Scheduler::Worker::Worker(Scheduler* scheduler, Mode mode, uint32_t id)
    : id(id),
      mode(mode),
      scheduler(scheduler),
      work(scheduler->cfg.allocator),
      idleFibers(scheduler->cfg.allocator) {}

void Scheduler::Worker::start() {
  switch (mode) {
    case Mode::MultiThreaded: {
      auto allocator = scheduler->cfg.allocator;
      auto& affinityPolicy = scheduler->cfg.workerThread.affinityPolicy;
      auto affinity = affinityPolicy->get(id, allocator);
      thread = Thread(std::move(affinity), [=, this] {
        Thread::setName("Thread<%.2d>", int(id));

        if (auto const& initFunc = scheduler->cfg.workerThread.initializer) {
          initFunc(id);
        }

        Scheduler::setBound(scheduler);
        Worker::current = this;
        mainFiber = Fiber::createFromCurrentThread(scheduler->cfg.allocator, 0);
        currentFiber = mainFiber.get();
        {
          marl::lock lock(work.mutex);
          run();
        }
        mainFiber.reset();
        Worker::current = nullptr;
      });
      break;
    }
    case Mode::SingleThreaded: {
      Worker::current = this;
      mainFiber = Fiber::createFromCurrentThread(scheduler->cfg.allocator, 0);
      currentFiber = mainFiber.get();
      break;
    }
    default:
      MARL_ASSERT(false, "Unknown mode: %d", int(mode));
  }
}

void Scheduler::Worker::stop() {
  switch (mode) {
    case Mode::MultiThreaded: {
      enqueue(Task([this] { shutdown = true; }, Task::Flags::SameThread));
      thread.join();
      break;
    }
    case Mode::SingleThreaded: {
      marl::lock lock(work.mutex);
      shutdown = true;
      runUntilShutdown();
      Worker::current = nullptr;
      break;
    }
    default:
      MARL_ASSERT(false, "Unknown mode: %d", int(mode));
  }
}

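// wait() suspends the current fiber until it is notified or the (optional)
// timeout is reached. Returns true if the wait ended before the timeout,
// false if the timeout expired first.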
bool Scheduler::Worker::wait(const TimePoint* timeout) {
  DBG_LOG("%d: WAIT(%d)", (int)id, (int)currentFiber->id);
  {
    marl::lock lock(work.mutex);
    suspend(timeout);
  }
  return timeout == nullptr || std::chrono::system_clock::now() < *timeout;
}

bool Scheduler::Worker::wait(lock& waitLock,
                             const TimePoint* timeout,
                             const Predicate& pred) {
  DBG_LOG("%d: WAIT(%d)", (int)id, (int)currentFiber->id);
  while (!pred()) {
    // Lock the work mutex to call suspend().
    work.mutex.lock();

    // Unlock the wait mutex with the work mutex lock held.
    // Order is important here as we need to ensure that the fiber is not
    // enqueued (via Fiber::notify()) between the waitLock.unlock() and the
    // fiber switch, otherwise the Fiber::notify() call may be ignored and the
    // fiber is never woken.
    waitLock.unlock_no_tsa();

    // Suspend the fiber.
    suspend(timeout);

    // Fiber resumed. We no longer need the work mutex locked.
    work.mutex.unlock();

    // Re-lock to either return due to timeout, or call pred().
    waitLock.lock_no_tsa();

    // Check timeout.
    if (timeout != nullptr && std::chrono::system_clock::now() >= *timeout) {
      return false;
    }

    // Spurious wake up. Spin again.
  }
  return true;
}

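// suspend() parks the current fiber and switches execution to another fiber
// that can make progress: an unblocked queued fiber, a pooled idle fiber, or
// a freshly created worker fiber. Must be called with work.mutex held. If a
// timeout is given, the fiber is registered with work.waiting so that it can
// be re-queued when the timeout expires.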
void Scheduler::Worker::suspend(
    const std::chrono::system_clock::time_point* timeout) {
  // Current fiber is yielding as it is blocked.
  if (timeout != nullptr) {
    changeFiberState(currentFiber, Fiber::State::Running,
                     Fiber::State::Waiting);
    work.waiting.add(*timeout, currentFiber);
  } else {
    changeFiberState(currentFiber, Fiber::State::Running,
                     Fiber::State::Yielded);
  }

  // First wait until there's something else this worker can do.
  waitForWork();

  work.numBlockedFibers++;

  if (!work.fibers.empty()) {
    // There's another fiber that has become unblocked, resume that.
    work.num--;
    auto to = containers::take(work.fibers);
    ASSERT_FIBER_STATE(to, Fiber::State::Queued);
    switchToFiber(to);
  } else if (!idleFibers.empty()) {
    // There's an old fiber we can reuse, resume that.
    auto to = containers::take(idleFibers);
    ASSERT_FIBER_STATE(to, Fiber::State::Idle);
    switchToFiber(to);
  } else {
    // There are tasks to process but no existing fibers to resume.
    // Spawn a new fiber to process them.
    switchToFiber(createWorkerFiber());
  }

  work.numBlockedFibers--;

  setFiberState(currentFiber, Fiber::State::Running);
}

bool Scheduler::Worker::tryLock() {
  return work.mutex.try_lock();
}

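// enqueue(Fiber*) marks the fiber as runnable, moving it to the queued list
// and waking the worker if it is waiting for work. This is the path taken by
// Fiber::notify(). Fibers that are already running or queued are left alone.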
void Scheduler::Worker::enqueue(Fiber* fiber) {
  bool notify = false;
  {
    marl::lock lock(work.mutex);
    DBG_LOG("%d: ENQUEUE(%d %s)", (int)id, (int)fiber->id,
            Fiber::toString(fiber->state));
    switch (fiber->state) {
      case Fiber::State::Running:
      case Fiber::State::Queued:
        return;  // Nothing to do here - task is already queued or running.
      case Fiber::State::Waiting:
        work.waiting.erase(fiber);
        break;
      case Fiber::State::Idle:
      case Fiber::State::Yielded:
        break;
    }
    notify = work.notifyAdded;
    work.fibers.push_back(fiber);
    MARL_ASSERT(!work.waiting.contains(fiber),
                "fiber is unexpectedly in the waiting list");
    setFiberState(fiber, Fiber::State::Queued);
    work.num++;
  }

  if (notify) {
    work.added.notify_one();
  }
}

void Scheduler::Worker::enqueue(Task&& task) {
  work.mutex.lock();
  enqueueAndUnlock(std::move(task));
}

void Scheduler::Worker::enqueueAndUnlock(Task&& task) {
  auto notify = work.notifyAdded;
  work.tasks.push_back(std::move(task));
  work.num++;
  work.mutex.unlock();
  if (notify) {
    work.added.notify_one();
  }
}

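// steal() attempts to take a task from this worker's queue on behalf of
// another worker. It is deliberately cheap to fail: the racy work.num check
// and try_lock() let thieves back off immediately instead of contending with
// the queue owner. Tasks flagged SameThread are never stolen.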
bool Scheduler::Worker::steal(Task& out) {
  if (work.num.load() == 0) {
    return false;
  }
  if (!work.mutex.try_lock()) {
    return false;
  }
  if (work.tasks.empty() || work.tasks.front().is(Task::Flags::SameThread)) {
    work.mutex.unlock();
    return false;
  }
  work.num--;
  out = containers::take(work.tasks);
  work.mutex.unlock();
  return true;
}

void Scheduler::Worker::run() {
  if (mode == Mode::MultiThreaded) {
    MARL_NAME_THREAD("Thread<%.2d> Fiber<%.2d>", int(id), Fiber::current()->id);
    // This is the entry point for a multi-threaded worker.
    // Start with a regular condition-variable wait for work. This avoids
    // starting the thread with a spinForWorkAndLock().
    work.wait([this]() REQUIRES(work.mutex) {
      return work.num > 0 || work.waiting || shutdown;
    });
  }
  ASSERT_FIBER_STATE(currentFiber, Fiber::State::Running);
  runUntilShutdown();
  switchToFiber(mainFiber.get());
}

void Scheduler::Worker::runUntilShutdown() {
  while (!shutdown || work.num > 0 || work.numBlockedFibers > 0U) {
    waitForWork();
    runUntilIdle();
  }
}

void Scheduler::Worker::waitForWork() {
  MARL_ASSERT(work.num == work.fibers.size() + work.tasks.size(),
              "work.num out of sync");
  if (work.num > 0) {
    return;
  }

  if (mode == Mode::MultiThreaded) {
    scheduler->onBeginSpinning(id);
    work.mutex.unlock();
    spinForWorkAndLock();
  }

  work.wait([this]() REQUIRES(work.mutex) {
    return work.num > 0 || (shutdown && work.numBlockedFibers == 0U);
  });
  if (work.waiting) {
    enqueueFiberTimeouts();
  }
}

void Scheduler::Worker::enqueueFiberTimeouts() {
  auto now = std::chrono::system_clock::now();
  while (auto fiber = work.waiting.take(now)) {
    changeFiberState(fiber, Fiber::State::Waiting, Fiber::State::Queued);
    DBG_LOG("%d: TIMEOUT(%d)", (int)id, (int)fiber->id);
    work.fibers.push_back(fiber);
    work.num++;
  }
}

void Scheduler::Worker::changeFiberState(Fiber* fiber,
                                         Fiber::State from,
                                         Fiber::State to) const {
  (void)from;  // Unused parameter when ENABLE_DEBUG_LOGGING is disabled.
  DBG_LOG("%d: CHANGE_FIBER_STATE(%d %s -> %s)", (int)id, (int)fiber->id,
          Fiber::toString(from), Fiber::toString(to));
  ASSERT_FIBER_STATE(fiber, from);
  fiber->state = to;
}

void Scheduler::Worker::setFiberState(Fiber* fiber, Fiber::State to) const {
  DBG_LOG("%d: SET_FIBER_STATE(%d %s -> %s)", (int)id, (int)fiber->id,
          Fiber::toString(fiber->state), Fiber::toString(to));
  fiber->state = to;
}

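// spinForWorkAndLock() burns up to ~1ms of CPU looking for work before the
// caller falls back to a blocking wait. Within the spin it alternates between
// checking the local queue and attempting to steal a task from a randomly
// chosen worker. On return, work.mutex is locked (hence the name).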
void Scheduler::Worker::spinForWorkAndLock() {
  TRACE("SPIN");
  Task stolen;

  constexpr auto duration = std::chrono::milliseconds(1);
  auto start = std::chrono::high_resolution_clock::now();
  while (std::chrono::high_resolution_clock::now() - start < duration) {
    for (int i = 0; i < 256; i++)  // Empirically picked magic number!
    {
      // clang-format off
      nop(); nop(); nop(); nop(); nop(); nop(); nop(); nop();
      nop(); nop(); nop(); nop(); nop(); nop(); nop(); nop();
      nop(); nop(); nop(); nop(); nop(); nop(); nop(); nop();
      nop(); nop(); nop(); nop(); nop(); nop(); nop(); nop();
      // clang-format on

      if (work.num > 0) {
        work.mutex.lock();
        if (work.num > 0) {
          return;
        } else {
          // Our new task was stolen by another worker. Keep spinning.
          work.mutex.unlock();
        }
      }
    }

    if (scheduler->stealWork(this, rng(), stolen)) {
      work.mutex.lock();
      work.tasks.emplace_back(std::move(stolen));
      work.num++;
      return;
    }

    std::this_thread::yield();
  }
  work.mutex.lock();
}

void Scheduler::Worker::runUntilIdle() {
  ASSERT_FIBER_STATE(currentFiber, Fiber::State::Running);
  MARL_ASSERT(work.num == work.fibers.size() + work.tasks.size(),
              "work.num out of sync");
  while (!work.fibers.empty() || !work.tasks.empty()) {
    // Note: we cannot take and store on the stack more than a single fiber
    // or task at a time, as the fiber may yield and these items would then be
    // held on the suspended fiber's stack.

    while (!work.fibers.empty()) {
      work.num--;
      auto fiber = containers::take(work.fibers);
      // Sanity checks.
      MARL_ASSERT(idleFibers.count(fiber) == 0, "dequeued fiber is idle");
      MARL_ASSERT(fiber != currentFiber, "dequeued fiber is currently running");
      ASSERT_FIBER_STATE(fiber, Fiber::State::Queued);

      changeFiberState(currentFiber, Fiber::State::Running, Fiber::State::Idle);
      auto added = idleFibers.emplace(currentFiber).second;
      (void)added;
      MARL_ASSERT(added, "fiber already idle");

      switchToFiber(fiber);
      changeFiberState(currentFiber, Fiber::State::Idle, Fiber::State::Running);
    }

    if (!work.tasks.empty()) {
      work.num--;
      auto task = containers::take(work.tasks);
      work.mutex.unlock();

      // Run the task.
      task();

      // std::function<> can carry arguments with complex destructors.
      // Ensure these are destructed outside of the lock.
      task = Task();

      work.mutex.lock();
    }
  }
}

Scheduler::Fiber* Scheduler::Worker::createWorkerFiber() {
  auto fiberId = static_cast<uint32_t>(workerFibers.size() + 1);
  DBG_LOG("%d: CREATE(%d)", (int)id, (int)fiberId);
  auto fiber = Fiber::create(scheduler->cfg.allocator, fiberId,
                             scheduler->cfg.fiberStackSize,
                             [&]() REQUIRES(work.mutex) { run(); });
  auto ptr = fiber.get();
  workerFibers.emplace_back(std::move(fiber));
  return ptr;
}

void Scheduler::Worker::switchToFiber(Fiber* to) {
  DBG_LOG("%d: SWITCH(%d -> %d)", (int)id, (int)currentFiber->id, (int)to->id);
  MARL_ASSERT(to == mainFiber.get() || idleFibers.count(to) == 0,
              "switching to idle fiber");
  auto from = currentFiber;
  currentFiber = to;
  from->switchTo(to);
}

////////////////////////////////////////////////////////////////////////////////
// Scheduler::Worker::Work
////////////////////////////////////////////////////////////////////////////////
Scheduler::Worker::Work::Work(Allocator* allocator)
    : tasks(allocator), fibers(allocator), waiting(allocator) {}

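// Work::wait() blocks on the `added` condition variable until the predicate
// `f` returns true, waking no later than the next fiber timeout if any fibers
// are waiting. notifyAdded is set for the duration of the wait so that
// enqueuers know to signal the condition variable.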
template <typename F>
void Scheduler::Worker::Work::wait(F&& f) {
  notifyAdded = true;
  if (waiting) {
    mutex.wait_until_locked(added, waiting.next(), f);
  } else {
    mutex.wait_locked(added, f);
  }
  notifyAdded = false;
}

////////////////////////////////////////////////////////////////////////////////
// Scheduler::SingleThreadedWorkers
////////////////////////////////////////////////////////////////////////////////
Scheduler::SingleThreadedWorkers::SingleThreadedWorkers(Allocator* allocator)
    : byTid(allocator) {}

}  // namespace marl