xref: /aosp_15_r20/external/federated-compute/fcp/base/scheduler.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2018 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/base/scheduler.h"
18 
19 #include <array>
20 #include <functional>
21 #include <memory>
22 #include <queue>
23 #include <thread>  // NOLINT(build/c++11)
24 #include <utility>
25 #include <vector>
26 
27 #include "absl/synchronization/blocking_counter.h"
28 #include "absl/synchronization/mutex.h"
29 
30 namespace fcp {
31 
32 namespace {
33 
// A helper class to track whether an object is still alive.
//
// The tracker owns a heap-allocated boolean, shared through a SharedMarker,
// which is true while the tracker exists and is set to false by the tracker's
// destructor. Code that may run after the tracked object is gone (for example
// a lambda) can hold a copy of the marker and CHECK it before touching the
// object.
//
// The flag is a plain bool: reading a marker does not by itself synchronize
// with a destruction happening concurrently on another thread.
class LifetimeTracker {
 public:
  using SharedMarker = std::shared_ptr<bool>;

  // Allocates the flag and marks the object as alive.
  LifetimeTracker() : marker_(std::make_shared<bool>(true)) {}

  // Not copyable (and therefore not movable either): a copy would share the
  // flag with the original, so destroying one of the two objects would mark
  // the other, still living one as destroyed.
  LifetimeTracker(const LifetimeTracker&) = delete;
  LifetimeTracker& operator=(const LifetimeTracker&) = delete;

  // Marks the object as destroyed for everybody still holding the marker.
  virtual ~LifetimeTracker() { *marker_ = false; }

  // Returns the marker. Callers are expected to copy it so that the flag
  // outlives the tracker.
  SharedMarker& marker() { return marker_; }

 private:
  SharedMarker marker_;
};
49 
50 // Implementation of workers.
51 class WorkerImpl : public Worker, public LifetimeTracker {
52  public:
WorkerImpl(Scheduler * scheduler)53   explicit WorkerImpl(Scheduler* scheduler) : scheduler_(scheduler) {}
54 
55   ~WorkerImpl() override = default;
56 
Schedule(std::function<void ()> task)57   void Schedule(std::function<void()> task) override {
58     absl::MutexLock lock(&busy_);
59     steps_.emplace_back(std::move(task));
60     MaybeRunNext();
61   }
62 
63  private:
MaybeRunNext()64   void MaybeRunNext() ABSL_EXCLUSIVE_LOCKS_REQUIRED(busy_) {
65     if (running_ || steps_.empty()) {
66       // Already running, and next task will be executed when finished, or
67       // nothing to run.
68       return;
69     }
70     auto task = std::move(steps_.front());
71     steps_.pop_front();
72     running_ = true;
73     auto wrapped_task = MoveToLambda(std::move(task));
74     auto marker = this->marker();
75     scheduler_->Schedule([this, marker, wrapped_task] {
76       // Call the Task which is stored in wrapped_task.value.
77       (*wrapped_task)();
78 
79       // Run the next task.
80       FCP_CHECK(*marker) << "Worker destroyed before all tasks finished";
81       {
82         // Try run next task if any.
83         absl::MutexLock lock(&this->busy_);
84         this->running_ = false;
85         this->MaybeRunNext();
86       }
87     });
88   }
89 
90   Scheduler* scheduler_;
91   absl::Mutex busy_;
92   bool running_ ABSL_GUARDED_BY(busy_) = false;
93   std::deque<std::function<void()>> steps_ ABSL_GUARDED_BY(busy_);
94 };
95 
96 // Implementation of thread pools.
97 class ThreadPoolScheduler : public Scheduler {
98  public:
ThreadPoolScheduler(std::size_t thread_count)99   explicit ThreadPoolScheduler(std::size_t thread_count)
100       : idle_condition_(absl::Condition(IdleCondition, this)),
101         active_count_(thread_count) {
102     FCP_CHECK(thread_count > 0) << "invalid thread_count";
103 
104     // Create threads.
105     for (int i = 0; i < thread_count; ++i) {
106       threads_.emplace_back(std::thread([this] { this->PerThreadActivity(); }));
107     }
108   }
109 
~ThreadPoolScheduler()110   ~ThreadPoolScheduler() override {
111     {
112       absl::MutexLock lock(&busy_);
113       FCP_CHECK(IdleCondition(this))
114           << "Thread pool must be idle at destruction time";
115 
116       threads_should_join_ = true;
117       work_available_cond_var_.SignalAll();
118     }
119 
120     for (auto& thread : threads_) {
121       FCP_CHECK(thread.joinable()) << "Attempted to destroy a threadpool from "
122                                       "one of its running threads";
123       thread.join();
124     }
125   }
126 
Schedule(std::function<void ()> task)127   void Schedule(std::function<void()> task) override {
128     absl::MutexLock lock(&busy_);
129     todo_.push(std::move(task));
130     // Wake up a *single* thread to handle this task.
131     work_available_cond_var_.Signal();
132   }
133 
WaitUntilIdle()134   void WaitUntilIdle() override {
135     busy_.LockWhen(idle_condition_);
136     busy_.Unlock();
137   }
138 
IdleCondition(ThreadPoolScheduler * pool)139   static bool IdleCondition(ThreadPoolScheduler* pool)
140       ABSL_EXCLUSIVE_LOCKS_REQUIRED(pool->busy_) {
141     return pool->todo_.empty() && pool->active_count_ == 0;
142   }
143 
PerThreadActivity()144   void PerThreadActivity() {
145     for (;;) {
146       std::function<void()> task;
147       {
148         absl::MutexLock lock(&busy_);
149         --active_count_;
150         while (todo_.empty()) {
151           if (threads_should_join_) {
152             return;
153           }
154 
155           work_available_cond_var_.Wait(&busy_);
156         }
157 
158         // Destructor invariant
159         FCP_CHECK(!threads_should_join_);
160         task = std::move(todo_.front());
161         todo_.pop();
162         ++active_count_;
163       }
164 
165       task();
166     }
167   }
168 
169   // A vector of threads allocated for execution.
170   std::vector<std::thread> threads_;
171 
172   // A CondVar used to signal availability of tasks.
173   //
174   // We would prefer to use the more declarative absl::Condition instead,
175   // however, this one only allows to wake up all threads if a new task is
176   // available -- but we want to wake up only one in this case.
177   absl::CondVar work_available_cond_var_;
178 
179   // See IdleCondition
180   absl::Condition idle_condition_;
181 
182   // A mutex protecting mutable state in this class.
183   absl::Mutex busy_;
184 
185   // Set when worker threads should join instead of waiting for work.
186   bool threads_should_join_ ABSL_GUARDED_BY(busy_) = false;
187 
188   // Queue of tasks with work to do.
189   std::queue<std::function<void()>> todo_ ABSL_GUARDED_BY(busy_);
190 
191   // The number of threads currently doing work in this pool.
192   std::size_t active_count_ ABSL_GUARDED_BY(busy_);
193 };
194 
195 }  // namespace
196 
CreateWorker()197 std::unique_ptr<Worker> Scheduler::CreateWorker() {
198   return std::make_unique<WorkerImpl>(this);
199 }
200 
CreateThreadPoolScheduler(std::size_t thread_count)201 std::unique_ptr<Scheduler> CreateThreadPoolScheduler(std::size_t thread_count) {
202   return std::make_unique<ThreadPoolScheduler>(thread_count);
203 }
204 
205 }  // namespace fcp
206