#include #if AT_PARALLEL_OPENMP || AT_PARALLEL_NATIVE #include #include #include #include namespace at { namespace { const int NOT_SET = -1; const int CONSUMED = -2; // Number of inter-op threads set by the user; // NOT_SET -> positive value -> CONSUMED // (CONSUMED - thread pool is initialized) // or // NOT_SET -> CONSUMED std::atomic num_interop_threads{NOT_SET}; // thread pool global instance is hidden, // users should use at::launch and get/set_num_interop_threads interface TaskThreadPoolBase& get_pool() { static std::shared_ptr pool = ThreadPoolRegistry()->Create( "C10", /* device_id */ 0, /* pool_size */ num_interop_threads.exchange(CONSUMED), /* create_new */ true); return *pool; } // Factory function for ThreadPoolRegistry std::shared_ptr create_c10_threadpool( int device_id, int pool_size, bool create_new) { // For now, the only accepted device id is 0 TORCH_CHECK(device_id == 0); // Create new thread pool TORCH_CHECK(create_new); return std::make_shared(pool_size); } } // namespace C10_REGISTER_CREATOR(ThreadPoolRegistry, C10, create_c10_threadpool); void set_num_interop_threads(int nthreads) { TORCH_CHECK(nthreads > 0, "Expected positive number of threads"); int no_value = NOT_SET; TORCH_CHECK(num_interop_threads.compare_exchange_strong(no_value, nthreads), "Error: cannot set number of interop threads after parallel work " "has started or set_num_interop_threads called"); } int get_num_interop_threads() { at::internal::lazy_init_num_threads(); int nthreads = num_interop_threads.load(); if (nthreads > 0) { return nthreads; } else if (nthreads == NOT_SET) { // return default value return TaskThreadPoolBase::defaultNumThreads(); } else { return get_pool().size(); } } namespace internal { void launch_no_thread_state(std::function fn) { #if AT_EXPERIMENTAL_SINGLE_THREAD_POOL intraop_launch(std::move(fn)); #else get_pool().run(std::move(fn)); #endif } } // namespace internal void launch(std::function func) { // NOLINTNEXTLINE(modernize-avoid-bind) internal::launch_no_thread_state(std::bind([]( std::function f, ThreadLocalState thread_locals) { ThreadLocalStateGuard guard(thread_locals); f(); }, std::move(func), ThreadLocalState() )); } } // namespace at #endif