// xref: /aosp_15_r20/external/tensorflow/tensorflow/core/util/mkl_threadpool.h
// (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_
#define TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_
#ifdef INTEL_MKL

#include <algorithm>
#include <functional>
#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "dnnl.hpp"
#include "dnnl_threadpool.hpp"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/threadpool.h"
#define EIGEN_USE_THREADS

namespace tensorflow {

#ifndef ENABLE_ONEDNN_OPENMP
using dnnl::threadpool_interop::threadpool_iface;

// Divide 'n' units of work equally among 'team' threads. If 'n' is not
// divisible by 'team' and has a remainder 'r', the first 'r' threads get one
// unit of work more than the rest. Returns the range of work that belongs to
// the thread 'tid'.
// Parameters
//   n        Total number of jobs.
//   team     Number of workers.
//   tid      Current thread id (0 <= tid < team).
//   n_start  Start of the range operated on by the thread.
//   n_end    End of the range operated on by the thread.

template <typename T, typename U>
inline void balance211(T n, U team, U tid, T* n_start, T* n_end) {
  // Degenerate cases: a single worker (or no work at all) takes the whole
  // range [0, n).
  if (team <= 1 || n == 0) {
    *n_start = 0;
    *n_end = n;
    return;
  }
  T min_per_team = n / team;
  T remainder = n - min_per_team * team;  // i.e., n % team.
  // The first 'remainder' threads each take one extra unit of work.
  // Cast 'tid' to T so the std::min call and the comparison compile when
  // T and U are different integer types (e.g. T=int64, U=int).
  *n_start = tid * min_per_team + std::min(static_cast<T>(tid), remainder);
  *n_end = *n_start + min_per_team + (static_cast<T>(tid) < remainder);
}
63 struct MklDnnThreadPool : public threadpool_iface {
64   MklDnnThreadPool() = default;
65 
66   MklDnnThreadPool(OpKernelContext* ctx, int num_threads = -1) {
67     eigen_interface_ = ctx->device()
68                            ->tensorflow_cpu_worker_threads()
69                            ->workers->AsEigenThreadPool();
70     num_threads_ =
71         (num_threads == -1) ? eigen_interface_->NumThreads() : num_threads;
72   }
get_num_threadsMklDnnThreadPool73   virtual int get_num_threads() const override { return num_threads_; }
get_in_parallelMklDnnThreadPool74   virtual bool get_in_parallel() const override {
75     return (eigen_interface_->CurrentThreadId() != -1) ? true : false;
76   }
get_flagsMklDnnThreadPool77   virtual uint64_t get_flags() const override { return ASYNCHRONOUS; }
parallel_forMklDnnThreadPool78   virtual void parallel_for(int n,
79                             const std::function<void(int, int)>& fn) override {
80     // Should never happen (handled by DNNL)
81     if (n == 0) return;
82 
83     // Should never happen (handled by DNNL)
84     if (n == 1) {
85       fn(0, 1);
86       return;
87     }
88 
89     int nthr = get_num_threads();
90     int njobs = std::min(n, nthr);
91     bool balance = (nthr < n);
92     for (int i = 0; i < njobs; i++) {
93       eigen_interface_->ScheduleWithHint(
94           [balance, i, n, njobs, fn]() {
95             if (balance) {
96               int start, end;
97               balance211(n, njobs, i, &start, &end);
98               for (int j = start; j < end; j++) fn(j, n);
99             } else {
100               fn(i, n);
101             }
102           },
103           i, i + 1);
104     }
105   }
~MklDnnThreadPoolMklDnnThreadPool106   ~MklDnnThreadPool() {}
107 
108  private:
109   Eigen::ThreadPoolInterface* eigen_interface_ = nullptr;
110   int num_threads_ = 1;  // Execute in caller thread.
111 };

#else

// This struct was added only to enable a successful OpenMP-based build.
116 struct MklDnnThreadPool {
117   MklDnnThreadPool() = default;
118   MklDnnThreadPool(OpKernelContext* ctx) {}
119   MklDnnThreadPool(OpKernelContext* ctx, int num_threads) {}
120 };

#endif  // !ENABLE_ONEDNN_OPENMP

}  // namespace tensorflow

#endif  // INTEL_MKL
#endif  // TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_