xref: /aosp_15_r20/external/pytorch/caffe2/utils/threadpool/ThreadPool.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #ifndef CAFFE2_UTILS_THREADPOOL_H_
2 #define CAFFE2_UTILS_THREADPOOL_H_
3 
4 #include "ThreadPoolCommon.h"
5 
6 #include <atomic>
7 #include <functional>
8 #include <memory>
9 #include <mutex>
10 #include <vector>
11 
12 #include "caffe2/core/common.h"
13 
//
// A work-stealing thread pool loosely based on pthreadpool.
//
17 
18 namespace caffe2 {
19 
20 struct Task;
21 class WorkersPool;
22 
23 constexpr size_t kCacheLineSize = 64;
24 
25 // A threadpool with the given number of threads.
26 // NOTE: the kCacheLineSize alignment is present only for cache
27 // performance, and is not strictly enforced (for example, when
28 // the object is created on the heap). Thus, in order to avoid
29 // misaligned intrinsics, no SSE instructions shall be involved in
30 // the ThreadPool implementation.
31 // Note: alignas is disabled because some compilers do not deal with
32 // TORCH_API and alignas annotations at the same time.
33 class TORCH_API /*alignas(kCacheLineSize)*/ ThreadPool {
34  public:
35   static ThreadPool* createThreadPool(int numThreads);
36   static std::unique_ptr<ThreadPool> defaultThreadPool();
37   virtual ~ThreadPool() = default;
38   // Returns the number of threads currently in use
39   virtual int getNumThreads() const = 0;
40   virtual void setNumThreads(size_t numThreads) = 0;
41 
42   // Sets the minimum work size (range) for which to invoke the
43   // threadpool; work sizes smaller than this will just be run on the
44   // main (calling) thread
setMinWorkSize(size_t size)45   void setMinWorkSize(size_t size) {
46     std::lock_guard<std::mutex> guard(executionMutex_);
47     minWorkSize_ = size;
48   }
49 
getMinWorkSize()50   size_t getMinWorkSize() const {
51     return minWorkSize_;
52   }
53   virtual void run(const std::function<void(int, size_t)>& fn, size_t range) = 0;
54 
55   // Run an arbitrary function in a thread-safe manner accessing the Workers
56   // Pool
57   virtual void withPool(const std::function<void(WorkersPool*)>& fn) = 0;
58 
59  protected:
60   static size_t defaultNumThreads_;
61   mutable std::mutex executionMutex_;
62   size_t minWorkSize_;
63 };
64 
65 size_t getDefaultNumThreads();
66 } // namespace caffe2
67 
68 #endif // CAFFE2_UTILS_THREADPOOL_H_
69