xref: /aosp_15_r20/external/eigen/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h (revision bf2c37156dfe67e5dfebd6d394bad8b2ab5804d4)
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <[email protected]>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H
12 
13 // evaluator for thread pool device
14 #ifdef EIGEN_USE_THREADS
15 
16 namespace Eigen {
17 
18 template<typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
19 struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, ThreadPoolDevice> :
20     public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, ThreadPoolDevice> > {
21 
22   typedef ThreadPoolDevice Device;
23 
24   typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
25   typedef TensorContractionEvaluatorBase<Self> Base;
26 
27   typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
28   typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
29   typedef typename XprType::Index Index;
30   typedef typename XprType::CoeffReturnType CoeffReturnType;
31   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
32 
33   enum {
34     Layout = TensorEvaluator<LeftArgType, Device>::Layout,
35   };
36 
37   // Most of the code is assuming that both input tensors are ColMajor. If the
38   // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
39   // If we want to compute A * B = C, where A is LHS and B is RHS, the code
40   // will pretend B is LHS and A is RHS.
41   typedef typename internal::conditional<
42     static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
43   typedef typename internal::conditional<
44     static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;
45 
46   static const int LDims =
47       internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
48   static const int RDims =
49       internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
50   static const int ContractDims = internal::array_size<Indices>::value;
51 
52   typedef array<Index, LDims> left_dim_mapper_t;
53   typedef array<Index, RDims> right_dim_mapper_t;
54 
55   typedef array<Index, ContractDims> contract_t;
56   typedef array<Index, LDims - ContractDims> left_nocontract_t;
57   typedef array<Index, RDims - ContractDims> right_nocontract_t;
58 
59   static const int NumDims = LDims + RDims - 2 * ContractDims;
60 
61   typedef DSizes<Index, NumDims> Dimensions;
62 
63   // typedefs needed in evalTo
64   typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
65   typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
66   typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;
67 
68   typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
69   typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
70 
71   TensorEvaluator(const XprType& op, const Device& device) :
72       Base(op, device) {}
73 
74   template <int Alignment>
75   void evalProduct(Scalar* buffer) const {
76     evalProductImpl<NoCallback, Alignment>(buffer, NoCallback());
77   }
78 
79   template <typename EvalToCallback, int Alignment>
80   void evalProductAsync(Scalar* buffer, EvalToCallback done) const {
81     evalProductImpl<EvalToCallback, Alignment>(buffer, std::move(done));
82   }
83 
84   template <typename DoneCallback, int Alignment>
85   void evalProductImpl(Scalar* buffer, DoneCallback done) const {
86     // This function computes a lot of heuristics in multiple steps, and it
87     // also has multiple exit points. To keep it sane, readable and all in one
88     // place, sync/async execution decision is made at runtime at the very end.
89     //
90     // (1) In sync mode we allocate Context on the stack, submit computations
91     //     to the device thread pool, and block on a barrier until it is
92     //     completed.
93     //
94     // (2) In async mode we allocate Context on the heap, and after all tasks
95     //     are finished, we call provided the done callback, and delete a
96     //     context from the heap.
97     //
98     // (*) EvalParallelContext & EvalShardedByInnerDimContext owns all the state
99     // and temporary buffers, requried for executing the tensor contraction.
100     // They are responsible for cleaning it up after contraction is done.
101     static const bool IsEvalInSyncMode =
102         std::is_same<DoneCallback, NoCallback>::value;
103 
104     const Index m = this->m_i_size;
105     const Index n = this->m_j_size;
106     const Index k = this->m_k_size;
107     if (m == 0 || n == 0 || k == 0) return;
108 
109     // Compute a set of algorithm parameters:
110     // - kernel block sizes (bm, bn, bk)
111     // - task grain sizes (number of kernels executed per task: gm, gn)
112     // - number of threads
113     // - sharding by row/column
114     // - parallel packing or first lhs then rhs
115     // and some derived parameters:
116     // - number of tasks (nm, nn, nk)
117     // - number of kernels (nm0, nn0)
118     // Unfortunately, all these parameters are tightly interdependent.
119     // So in some cases we first compute approximate values, then compute other
120     // values based on these approximations and then refine the approximations.
121 
122     // There are lots of heuristics here. There is some reasoning behind them,
123     // but ultimately they are just tuned on contraction benchmarks for
124     // different input configurations, thread counts and instruction sets.
125     // So feel free to question any of them.
126 
127     // Compute whether we want to shard by row or by column.
128     // This is a first approximation, it will be refined later. Since we don't
129     // know number of threads yet we use 2, because what's we are most
130     // interested in at this point is whether it makes sense to use
131     // parallelization at all or not.
132     bool shard_by_col = shardByCol(m, n, 2);
133 
134     // First approximation of kernel blocking sizes.
135     // Again, we don't know number of threads yet, so we use 2.
136     Index bm, bn, bk;
137     if (shard_by_col) {
138       internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index,
139                                           internal::ShardByCol>
140           blocking(k, m, n, 2);
141       bm = blocking.mc();
142       bn = blocking.nc();
143       bk = blocking.kc();
144     } else {
145       internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index,
146                                           internal::ShardByRow>
147           blocking(k, m, n, 2);
148       bm = blocking.mc();
149       bn = blocking.nc();
150       bk = blocking.kc();
151     }
152 
153     // Compute optimal number of threads.
154     // Note: we use bk instead of k here because we are interested in amount of
155     // _parallelizable_ computations, and computations are not parallelizable
156     // across k dimension.
157     const TensorOpCost cost =
158         contractionCost(m, n, bm, bn, bk, shard_by_col, false);
159     int num_threads = TensorCostModel<ThreadPoolDevice>::numThreads(
160         static_cast<double>(n) * m, cost, this->m_device.numThreads());
161     int num_threads_by_k = numThreadsInnerDim(m, n, k);
162     if (shardByInnerDim(m, n, k, num_threads, num_threads_by_k)) {
163       // We are in the scenario where it is more effective to shard by the
164       // inner dimension.
165       if (IsEvalInSyncMode) {
166         EvalShardedByInnerDimContext<DoneCallback> ctx(
167             this, num_threads_by_k, buffer, m, n, k, std::move(done));
168         ctx.template run<Alignment>();
169       } else {
170         auto* ctx = new EvalShardedByInnerDimContext<DoneCallback>(
171             this, num_threads_by_k, buffer, m, n, k, std::move(done));
172         ctx->template runAsync<Alignment>();
173       }
174 
175       return;
176     }
177 
178     // TODO(dvyukov): this is a stop-gap to prevent regressions while the cost
179     // model is not tuned. Remove this when the cost model is tuned.
180     if (n == 1) num_threads = 1;
181 
182     if (num_threads == 1) {
183       TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential,
184                                   Unaligned, (buffer));
185       if (!IsEvalInSyncMode) done();
186       return;
187     }
188 
189     // Now that we know number of threads, recalculate sharding and blocking.
190     shard_by_col = shardByCol(m, n, num_threads);
191     if (shard_by_col) {
192       internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index,
193                                           internal::ShardByCol>
194           blocking(k, m, n, num_threads);
195       bm = blocking.mc();
196       bn = blocking.nc();
197       bk = blocking.kc();
198     } else {
199       internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index,
200                                           internal::ShardByRow>
201           blocking(k, m, n, num_threads);
202       bm = blocking.mc();
203       bn = blocking.nc();
204       bk = blocking.kc();
205     }
206 
207     // Number of kernels for each dimension.
208     Index nm0 = divup(m, bm);
209     Index nn0 = divup(n, bn);
210     Index nk = divup(k, bk);
211 
212     // Calculate task grain size (number of kernels executed per task).
213     // This task size coarsening serves two purposes:
214     // 1. It reduces per-task overheads including synchronization overheads.
215     // 2. It allows to use caches better (reuse the same packed rhs in several
216     // consecutive kernels).
217     Index gm = 1;
218     Index gn = 1;
219     // If we are sharding by column, then we prefer to reduce rows first.
220     if (shard_by_col) {
221       gm = coarsenM(m, n, bm, bn, bk, gn, num_threads, shard_by_col);
222       gn = coarsenN(m, n, bm, bn, bk, gm, num_threads, shard_by_col);
223     } else {
224       gn = coarsenN(m, n, bm, bn, bk, gm, num_threads, shard_by_col);
225       gm = coarsenM(m, n, bm, bn, bk, gn, num_threads, shard_by_col);
226     }
227     // Number of tasks in each dimension.
228     Index nm = divup(nm0, gm);
229     Index nn = divup(nn0, gn);
230 
231     // If there is enough concurrency in the sharding dimension, we choose not
232     // to paralellize by the other dimension, and execute all kernels in sync
233     // mode. This reduces parallelism from the nm x nn down to nn
234     // (shard_by_col==true) or nm (shard_by_col==false).
235     const Index sharding_dim_tasks = shard_by_col ? nn : nm;
236     const int num_worker_threads = this->m_device.numThreadsInPool();
237 
238     // With small number of threads we want to make sure that we do not reduce
239     // parallelism too much. With large number of threads we trade maximum
240     // parallelism for better memory locality.
241     const float oversharding_factor =
242         num_worker_threads <= 4  ? 8.0 :
243         num_worker_threads <= 8  ? 4.0 :
244         num_worker_threads <= 16 ? 2.0 :
245         num_worker_threads <= 32 ? 1.0 :
246         num_worker_threads <= 64 ? 0.8 : /* num_worker_threads > 64 */ 0.6;
247 
248     const bool parallelize_by_sharding_dim_only =
249         sharding_dim_tasks >= oversharding_factor * num_worker_threads;
250 
251     // Last by not least, decide whether we want to issue both lhs and rhs
252     // packing in parallel; or issue lhs packing first, and then issue rhs
253     // packing when lhs packing completes (for !shard_by_col lhs and rhs are
254     // swapped). Parallel packing allows more parallelism (for both packing and
255     // kernels), while sequential packing provides better locality (once
256     // a thread finishes rhs packing it proceed to kernels with that rhs).
257     // First, we are interested in parallel packing if there are few tasks.
258     bool parallel_pack = num_threads >= nm * nn;
259     // Also do parallel packing if all data fits into L2$.
260     if (m * bk * Index(sizeof(LhsScalar)) + n * bk * Index(sizeof(RhsScalar)) <=
261         l2CacheSize() * num_threads)
262       parallel_pack = true;
263     // But don't do it if we will use each rhs only once. Locality seems to be
264     // more important in this case.
265     if ((shard_by_col ? nm : nn) == 1) parallel_pack = false;
266     // Also don't get in the way of parallelize_by_sharding_dim_only
267     // optimization.
268     if (parallelize_by_sharding_dim_only) parallel_pack = false;
269 
270     // TODO(ezhulnev): With if contexpr we don't need SyncEvalParallelContext.
271     if (IsEvalInSyncMode) {
272 #define CONTEXT_ARGS                                                        \
273   (this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, \
274    nn0, shard_by_col, parallel_pack, parallelize_by_sharding_dim_only,      \
275    NoCallback())                                                            \
276       .run()
277       TENSOR_CONTRACTION_DISPATCH(SyncEvalParallelContext, Alignment,
278                                   CONTEXT_ARGS);
279 #undef CONTEXT_ARGS
280 
281     } else {
282 #define CONTEXT_ARGS                                                        \
283   (this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, \
284    nn0, shard_by_col, parallel_pack, parallelize_by_sharding_dim_only,      \
285    std::move(done))
286       TENSOR_CONTRACTION_ASYNC_DISPATCH(EvalParallelContext, DoneCallback,
287                                         Alignment, CONTEXT_ARGS, run());
288 #undef CONTEXT_ARGS
289     }
290   }
291 
292   // ------------------------------------------------------------------------ //
293 
294   // Dummy struct to represent an empty DoneCallback.
295 
296   struct NoCallback {
297     void operator()() {
298       eigen_assert(false && "NoCallback should never be called");
299     }
300   };
301 
302   // ------------------------------------------------------------------------ //
303 
304   template <typename DoneCallback, typename Context>
305   class EvalParallelNotification;
306 
307   // Synchronous evaluation notification that blocks caller thread in Wait().
308   template <typename Context>
309   class EvalParallelNotification<NoCallback, Context> {
310    public:
311     EvalParallelNotification(Context*, NoCallback) {}
312     void Notify() { done_.Notify(); }
313     void Wait() { done_.Wait(); }
314    private:
315     Eigen::Notification done_;
316   };
317 
318   // Asynchronous evaluation notification that does not block in Wait().
319   template <typename DoneCallback, typename Context>
320   class EvalParallelNotification {
321    public:
322     EvalParallelNotification(Context* ctx, DoneCallback done)
323         : ctx_(ctx), done_(std::move(done)) {}
324 
325     void Notify() {
326       // Make a copy of done callback, because it will be destructed when we
327       // will delete context in the next line (EvalParallelNotification is a
328       // data member of EvalParallelContext class).
329       DoneCallback done_copy = std::move(done_);
330 
331       // Delete parallel evaluation context.
332       delete ctx_;
333 
334       // Now safely call the done callback.
335       done_copy();
336     }
337 
338     void Wait() {}
339 
340    private:
341     Context* ctx_;
342     DoneCallback done_;
343   };
344 
345   // Context orchestrates sync/async parallel contraction evaluation. When it is
346   // executed in asynchronous mode, it owns all the shared state that might be
347   // accessible by block packing and kernel tasks.
348 
349   template <typename DoneCallback, bool lhs_inner_dim_contiguous,
350             bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered,
351             int Alignment>
352   class EvalParallelContext {
353    public:
354     typedef internal::TensorContractionInputMapper<
355         LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t,
356         contract_t, internal::packet_traits<LhsScalar>::size,
357         lhs_inner_dim_contiguous, false, Unaligned>
358         LhsMapper;
359     typedef internal::TensorContractionInputMapper<
360         RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t,
361         contract_t, internal::packet_traits<RhsScalar>::size,
362         rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Unaligned>
363         RhsMapper;
364 
365     typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
366 
367     typedef internal::TensorContractionKernel<
368         Scalar, LhsScalar, RhsScalar, Index, OutputMapper, LhsMapper, RhsMapper>
369         TensorContractionKernel;
370 
371     typedef typename TensorContractionKernel::LhsBlock LhsBlock;
372     typedef typename TensorContractionKernel::RhsBlock RhsBlock;
373     typedef typename TensorContractionKernel::BlockMemHandle BlockMemHandle;
374 
375     EvalParallelContext(const Self* self, int num_threads, Scalar* buffer,
376                         Index tm, Index tn, Index tk, Index bm, Index bn,
377                         Index bk, Index nm, Index nn, Index nk, Index gm,
378                         Index gn, Index nm0, Index nn0, bool shard_by_col,
379                         bool parallel_pack,
380                         bool parallelize_by_sharding_dim_only,
381                         DoneCallback done)
382         : created_by_thread_id_(std::this_thread::get_id()),
383           done_(this, std::move(done)),
384           device_(self->m_device),
385           lhs_(self->m_leftImpl, self->m_left_nocontract_strides,
386                self->m_i_strides, self->m_left_contracting_strides,
387                self->m_k_strides),
388           rhs_(self->m_rightImpl, self->m_right_nocontract_strides,
389                self->m_j_strides, self->m_right_contracting_strides,
390                self->m_k_strides),
391           buffer_(buffer),
392           output_(buffer, tm),
393           output_kernel_(self->m_output_kernel),
394           tensor_contraction_params_(self->m_tensor_contraction_params),
395           num_threads_(num_threads),
396           shard_by_col_(shard_by_col),
397           parallel_pack_(parallel_pack),
398           parallelize_by_sharding_dim_only_(parallelize_by_sharding_dim_only),
399           m_(tm),
400           n_(tn),
401           k_(tk),
402           bm_(bm),
403           bn_(bn),
404           bk_(bk),
405           nm_(nm),
406           nn_(nn),
407           nk_(nk),
408           gm_(gm),
409           gn_(gn),
410           nm0_(nm0),
411           nn0_(nn0),
412           kernel_(m_, k_, n_, bm_, bk_, bn_),
413           num_thread_local_allocations_(0),
414           // We reserve 2X more capacity for a thread local values, than the
415           // number of threads in the pool to efficiently handle task stealing
416           // by threads that are not managed by the pool.
417           thread_local_capacity(2 * (parallelize_by_sharding_dim_only_
418                                          ? device_.numThreadsInPool()
419                                          : 0)),
420           // We will use only one of the Lhs/Rhs thread local storage depending
421           // on the shard_by_col value and we parallelize by sharding dim ONLY.
422           lhs_thread_local_blocks_(shard_by_col_ ? 0 : thread_local_capacity,
423                                    {*this}, {*this}),
424           rhs_thread_local_blocks_(shard_by_col_ ? thread_local_capacity : 0,
425                                    {*this}, {*this}) {
426       // These two options are mutually exclusive.
427       eigen_assert(!(parallel_pack && parallelize_by_sharding_dim_only));
428 
429       for (Index x = 0; x < P; x++) {
430         // Normal number of notifications for k slice switch is
431         // nm_ + nn_ + nm_ * nn_. However, first P - 1 slices will receive only
432         // nm_ + nn_ notifications, because they will not receive notifications
433         // from preceding kernels.
434         state_switch_[x] =
435             x == 0
436                 ? 1
437                 : (parallel_pack_ ? nn_ + nm_ : (shard_by_col_ ? nn_ : nm_)) +
438                       (x == P - 1 ? nm_ * nn_ : 0);
439         state_packing_ready_[x] =
440             parallel_pack_ ? 0 : (shard_by_col_ ? nm_ : nn_);
441         state_kernel_[x] = new std::atomic<uint8_t>*[nm_];
442         for (Index m = 0; m < nm_; m++) {
443           state_kernel_[x][m] = new std::atomic<uint8_t>[nn_];
444           // Kernels generally receive 3 notifications (previous kernel + 2
445           // packing), but the first slice won't get notifications from previous
446           // kernels.
447           for (Index n = 0; n < nn_; n++)
448             state_kernel_[x][m][n].store(
449                 (x == 0 ? 0 : 1) + (parallel_pack_ ? 2 : 1),
450                 std::memory_order_relaxed);
451         }
452       }
453 
454       // Allocate memory for packed rhs/lhs matrices.
455       packed_mem_ = kernel_.allocateSlices(            //
456           device_,                                     //
457           /*num_lhs=*/nm0_,                            //
458           /*num_rhs=*/nn0_,                            //
459           /*num_slices=*/std::min<Index>(nk_, P - 1),  //
460           packed_lhs_, packed_rhs_);
461 
462       if (parallelize_by_sharding_dim_only_) {
463         const int num_worker_threads = device_.numThreadsInPool();
464 
465         if (shard_by_col) {
466           can_use_thread_local_packed_ = new std::atomic<bool>[nn_];
467           for (int i = 0; i < nn_; ++i)
468             can_use_thread_local_packed_[i].store(true,
469                                                   std::memory_order_relaxed);
470 
471           Index num_blocks = num_worker_threads * gn_;
472           thread_local_pre_alocated_mem_ = kernel_.allocateSlices(  //
473               device_,                                              //
474               /*num_lhs=*/0,                                        //
475               /*num_rhs=*/num_blocks,                               //
476               /*num_slices=*/1,                                     //
477               /*lhs_blocks=*/nullptr, &rhs_thread_local_pre_allocated_);
478 
479         } else {
480           can_use_thread_local_packed_ = new std::atomic<bool>[nm_];
481           for (int i = 0; i < nm_; ++i)
482             can_use_thread_local_packed_[i].store(true,
483                                                   std::memory_order_relaxed);
484 
485           Index num_blocks = num_worker_threads * gm_;
486           thread_local_pre_alocated_mem_ = kernel_.allocateSlices(  //
487               device_,                                              //
488               /*num_lhs=*/num_blocks,                               //
489               /*num_rhs=*/0,                                        //
490               /*num_slices=*/1, &lhs_thread_local_pre_allocated_,   //
491               /*rhs_blocks=*/nullptr);
492         }
493       }
494     }
495 
496     ~EvalParallelContext() {
497       for (Index x = 0; x < P; x++) {
498         for (Index m = 0; m < nm_; m++) delete[] state_kernel_[x][m];
499         delete[] state_kernel_[x];
500       }
501       kernel_.deallocate(device_, packed_mem_);
502       if (parallelize_by_sharding_dim_only_) {
503         kernel_.deallocate(device_, thread_local_pre_alocated_mem_);
504         delete[] can_use_thread_local_packed_;
505       }
506     }
507 
508     void run() {
509       // Kick off packing of the first slice.
510       signal_switch(0, 1);
511 
512       // Wait for overall completion.
513       //
514       // If parallel evaluation is executed in async mode, this is a no-op, and
515       // Wait() will return immediately. In synchronous mode it will block the
516       // caller thread until it will receive notification from last task.
517       //
518       // In async mode, last task when completed will call done callback from
519       // the same thread, and will delete this context.
520       //
521       // TODO(dvyukov): This wait can lead to deadlock if contraction is
522       // evaluated in synchronous mode. If nthreads contractions are
523       // concurrently submitted from worker threads, this wait will block all
524       // worker threads and the system will deadlock.
525       done_.Wait();
526     }
527 
528    private:
529     std::thread::id created_by_thread_id_;
530 
531     // This notification is specialized on the type of DoneCallback and can be
532     // blocking or non-blocking.
533     EvalParallelNotification<DoneCallback, EvalParallelContext> done_;
534 
535     const Device& device_;
536     LhsMapper lhs_;
537     RhsMapper rhs_;
538     Scalar* const buffer_;
539     OutputMapper output_;
540     OutputKernelType output_kernel_;
541     TensorContractionParams tensor_contraction_params_;
542     const int num_threads_;
543     const bool shard_by_col_;
544     const bool parallel_pack_;
545     const bool parallelize_by_sharding_dim_only_;
546     // Matrix sizes.
547     const Index m_;
548     const Index n_;
549     const Index k_;
550     // Block sizes.
551     const Index bm_;
552     const Index bn_;
553     const Index bk_;
554     // Number of tasks.
555     const Index nm_;
556     const Index nn_;
557     const Index nk_;
558     // Task grain sizes (number of kernels executed per task).
559     const Index gm_;
560     const Index gn_;
561     // Number of blocks (this is different from ni_/nn_ because of task size
562     // coarsening).
563     const Index nm0_;
564     const Index nn0_;
565     // Tensor contraction kernel.
566     TensorContractionKernel kernel_;
567 
568     // Parallelization strategy.
569     //
570     // Blocks related to the same k block can run in parallel because they write
571     // to different output blocks. So we parallelize within k slices, this
572     // gives us parallelism level of m x n. Before we can start any kernels
573     // related to k-th slice, we need to issue m lhs packing tasks and n rhs
574     // packing tasks.
575     //
576     // However, there is a bottleneck when we are finishing kernels for k-th
577     // slice (at the very end there is only 1 runnable kernel). To mitigate this
578     // bottleneck we allow kernels from k-th and k+1-th slices to run in
579     // parallel. Note that (m, n, k) and (m, n, k+1) kernels write to the same
580     // output block, so they must not run in parallel.
581     //
582     // This gives us the following dependency graph.
583     // On each k slice we have m x n kernel tasks, m lhs paking tasks and n rhs
584     // packing tasks.
585     // Kernel (m, n, k) can start when:
586     //  - kernel (m, n, k-1) has finished
587     //  - lhs packing (m, k) has finished
588     //  - rhs packing (n, k) has finished
589     // Lhs/rhs packing can start when:
590     //  - all k-1 packing has finished (artificially imposed to limit amount of
591     //  parallel packing)
592     //
593     // On top of that we limit runnable tasks to two consecutive k slices.
594     // This is done to limit amount of memory we need for packed lhs/rhs
595     // (for each k slice we need m*bk + n*bk memory in packed_lhs_/packed_rhs_).
596     //
597     // state_switch_ tracks when we are ready to switch to the next k slice.
598     // state_kernel_[m][n] tracks when we are ready to kick off kernel (m, n).
599     // These variable are rolling over 3 consecutive k slices: first two we are
600     // actively executing + one to track completion of kernels in the second
601     // slice.
602     static const Index P = 3;
603 
604     // Handle to the allocated temporary storage for Lhs/Rhs blocks.
605     BlockMemHandle packed_mem_;
606     std::vector<LhsBlock> packed_lhs_[P - 1];
607     std::vector<RhsBlock> packed_rhs_[P - 1];
608 
609     // If we choose to parallelize only by the sharding dimension, each thread
610     // will have it's own "thead local" (not a c++ thread local storage) memory
611     // for packed_lhs or packed_rhs (shard_by_col = false of true). This memory
612     // can't be passed to a kernel that might execute on a different thread.
613     //
614     // In practice when we are ready to pack memory for the sharding dimension
615     // (rhs if shard_by_col==true) of the K-th slice, all kernels for K-1 slice
616     // already computed (99% of the time), and we can pack data into the thread
617     // local storage, and guarantee that all the kernels will be executed
618     // immediately in the same thread. This significantly increases L1 cache hit
619     // ratio and reduces pressure on the memory bus.
620     //
621     // It's still possible that kernel for the K-th slice will be ready before
622     // completion of the K-1 kernel, so we have to allocate "global" packed_lhs_
623     // and packed_rhs_ to allow kernels to be executed later on a thread
624     // different from the thread that was used for packing.
625 
626     // Handle for pre-allocated thread local memory buffers.
627     BlockMemHandle thread_local_pre_alocated_mem_;
628 
629     // Only one of these will be initialized depending on shard_by_col value
630     // (the size will be `num_worker_threads * num_grains_in_the_sharding_dim`).
631     std::vector<LhsBlock> lhs_thread_local_pre_allocated_;
632     std::vector<RhsBlock> rhs_thread_local_pre_allocated_;
633 
634     // How many thread local blocks were already allocated.
635     std::atomic<int> num_thread_local_allocations_;
636     const int thread_local_capacity;
637 
638     // We will use pre-allocated Lhs/Rhs blocks defined above, if the number of
639     // unique threads in a system is below or equal to the number of threads in
640     // a thread pool. We will fallback on dynamic memory allocation after that.
641 
642     // ThreadLocalBlocks is a container for Lhs or Rhs thread local buffers. Its
643     // size is equal to the grain size in Lhs/Rhs sharding dimension.
644     template <typename BlockType>
645     class ThreadLocalBlocks {
646      public:
647       ThreadLocalBlocks() = default;
648 
649       ThreadLocalBlocks(BlockType* base, size_t grain_size)
650           : is_pre_allocated_(true),
651             thread_local_pre_allocated_base_(base),
652             grain_size_(grain_size) {}
653 
654       ThreadLocalBlocks(BlockMemHandle mem_handle,
655                         std::vector<BlockType> blocks)
656           : is_pre_allocated_(false),
657             mem_handle_(std::move(mem_handle)),
658             blocks_(std::move(blocks)) {}
659 
660       BlockType& block(int grain_index) {
661         eigen_assert(grain_index >= 0);
662         eigen_assert(static_cast<size_t>(grain_index) < size());
663         return is_pre_allocated_ ? thread_local_pre_allocated_base_[grain_index]
664                                  : blocks_[grain_index];
665       }
666 
667       void Release(EvalParallelContext& ctx) const {
668         if (!is_pre_allocated_) {
669           ctx.kernel_.deallocate(ctx.device_, mem_handle_);
670         }
671       }
672 
673       size_t size() const {
674         return is_pre_allocated_ ? grain_size_ : blocks_.size();
675       }
676 
677      private:
678       bool is_pre_allocated_;
679 
680       // Reuse pre-allocated thread local buffers.
681       BlockType* thread_local_pre_allocated_base_ = nullptr;
682       size_t grain_size_ = 0;
683 
684       // These will be initialized only if `is_pre_allocated == false`.
685       BlockMemHandle mem_handle_{};
686       std::vector<BlockType> blocks_;
687     };
688 
689     // ThreadLocalBlocksInitialize callable does custom thread local blocks
690     // initialization, and will reuse pre-allocated buffers if possible, or will
691     // dynamically allocate new memory.
692     //
693     // Lhs/Rhs blocks might be of the same type, so we have to pass explicitly
694     // for what side do we plan to do block allocation.
695     template <typename BlockType, bool is_rhs>
696     class ThreadLocalBlocksInitialize {
697       static constexpr bool kIsLhs =
698           !is_rhs && std::is_same<BlockType, LhsBlock>::value;
699       static const bool kIsRhs =
700           is_rhs && std::is_same<BlockType, RhsBlock>::value;
701       static_assert(kIsLhs || kIsRhs, "Unkown block type");
702 
703       using Blocks = ThreadLocalBlocks<BlockType>;
704 
705      public:
706       ThreadLocalBlocksInitialize(EvalParallelContext& ctx)
707           : ctx_(ctx),
708             num_worker_threads_(ctx_.device_.numThreadsInPool()) {}
709 
710       void operator()(Blocks& blocks) {
711         const int n = ctx_.num_thread_local_allocations_.fetch_add(
712             1, std::memory_order_relaxed);
713 
714         if (n >= num_worker_threads_) {
715           ThreadLocalBlocksAllocator<is_rhs>::allocate(ctx_, blocks);
716         } else {
717           ThreadLocalBlocksAllocator<is_rhs>::reuse(ctx_, n, blocks);
718         }
719       }
720 
721      private:
722       // NOTE(ezhulenev): Without 'if constexpr' we have to put calls to
723       // TensorContractionKernel::allocateSlices into template specializations.
724       // Also explicit specializations are not allowed at class scope in C++03,
725       // EvalCtx type parameter is just a workaround for that limitation.
726       template <bool pack_rhs, typename EvalCtx = EvalParallelContext>
727       struct ThreadLocalBlocksAllocator;
728 
729       template <typename EvalCtx>
730       struct ThreadLocalBlocksAllocator</*pack_rhs=*/true, EvalCtx> {
731         static void allocate(EvalCtx& ctx, Blocks& blocks) {
732           std::vector<RhsBlock> rhs_blocks;
733           BlockMemHandle mem_handle = ctx.kernel_.allocateSlices(
734               ctx.device_,
735               /*num_lhs=*/0,
736               /*num_rhs=*/ctx.gn_,
737               /*num_slices=*/1,
738               /*lhs_blocks=*/nullptr, /*rhs_blocks=*/&rhs_blocks);
739 
740           blocks = ThreadLocalBlocks<RhsBlock>(std::move(mem_handle),
741                                                std::move(rhs_blocks));
742         }
743 
744         static void reuse(EvalCtx& ctx, int index, Blocks& blocks) {
745           RhsBlock* ptr = &ctx.rhs_thread_local_pre_allocated_[ctx.gn_ * index];
746           blocks = ThreadLocalBlocks<RhsBlock>(ptr, ctx.gn_);
747         }
748       };
749 
750       template <typename EvalCtx>
751       struct ThreadLocalBlocksAllocator</*pack_rhs=*/false, EvalCtx> {
752         static void allocate(EvalCtx& ctx, Blocks& blocks) {
753           std::vector<LhsBlock> lhs_blocks;
754           BlockMemHandle mem_handle = ctx.kernel_.allocateSlices(
755               ctx.device_,
756               /*num_lhs=*/ctx.gm_,
757               /*num_rhs=*/0,
758               /*num_slices=*/1,
759               /*lhs_blocks=*/&lhs_blocks, /*rhs_blocks=*/nullptr);
760 
761           blocks = ThreadLocalBlocks<LhsBlock>(std::move(mem_handle),
762                                                std::move(lhs_blocks));
763         }
764 
765         static void reuse(EvalCtx& ctx, int index, Blocks& blocks) {
766           LhsBlock* ptr = &ctx.lhs_thread_local_pre_allocated_[ctx.gm_ * index];
767           blocks = ThreadLocalBlocks<LhsBlock>(ptr, ctx.gm_);
768         }
769       };
770 
771       EvalParallelContext& ctx_;
772       const int num_worker_threads_;
773     };
774 
775     template <typename BlockType>
776     class ThreadLocalBlocksRelease {
777      public:
778       using Blocks = ThreadLocalBlocks<BlockType>;
779       ThreadLocalBlocksRelease(EvalParallelContext& ctx) : ctx_(ctx) {}
780       void operator()(Blocks& blocks) { blocks.Release(ctx_); }
781 
782      private:
783       EvalParallelContext& ctx_;
784     };
785 
786     // ThreadLocalBlocks initialization callables.
787     using ThreadLocalLhsInit =
788         ThreadLocalBlocksInitialize<LhsBlock, /*is_rhs=*/false>;
789     using ThreadLocalRhsInit =
790         ThreadLocalBlocksInitialize<RhsBlock, /*is_rhs=*/true>;
791 
792     // ThreadLocalBlocks release callables.
793     using ThreadLocalLhsRelease = ThreadLocalBlocksRelease<LhsBlock>;
794     using ThreadLocalRhsRelease = ThreadLocalBlocksRelease<RhsBlock>;
795 
796     // Thread local containers for Lhs/Rhs block packs. In practice only one of
797     // them will be used, depending on the shard_by_col value.
798     Eigen::ThreadLocal<ThreadLocalBlocks<LhsBlock>, ThreadLocalLhsInit,
799                        ThreadLocalLhsRelease>
800         lhs_thread_local_blocks_;
801     Eigen::ThreadLocal<ThreadLocalBlocks<RhsBlock>, ThreadLocalRhsInit,
802                        ThreadLocalRhsRelease>
803         rhs_thread_local_blocks_;
804 
805     // After a particular shard for Kth slice missed thread local execution
806     // opportunity (K-1 slice didn't complete kernels execution), we can no
807     // longer schedule K+1 and following slices in thread local mode, because
808     // there is no more guarantee that previous kernels were executed
809     // sequentially in the same thread (size is nn_ or nm_).
810     std::atomic<bool>* can_use_thread_local_packed_;
811 
812     std::atomic<uint8_t>** state_kernel_[P];
813     // state_switch_ is frequently modified by worker threads, while other
814     // fields are read-only after constructor. Let's move it to a separate cache
815     // line to reduce cache-coherency traffic.
816     char pad_[128];
817     std::atomic<Index> state_packing_ready_[P];
818     std::atomic<Index> state_switch_[P];
819 
820     LhsBlock& packed_lhs(Index m, Index k, Index m1, bool use_thread_local) {
821       if (use_thread_local) {
822         eigen_assert(!shard_by_col_);
823         ThreadLocalBlocks<LhsBlock>& blocks = lhs_thread_local_blocks_.local();
824 
825         Index grain_index = m1 - m * gm_;
826         return blocks.block(internal::convert_index<int>(grain_index)); // FIXME better make ThreadLocalBlocks use Eigen::Index?
827       } else {
828         return packed_lhs_[k % (P - 1)][m1];
829       }
830     }
831 
832     RhsBlock& packed_rhs(Index n, Index k, Index n1, bool use_thread_local) {
833       if (use_thread_local) {
834         eigen_assert(shard_by_col_);
835         ThreadLocalBlocks<RhsBlock>& blocks = rhs_thread_local_blocks_.local();
836 
837         Index grain_index = n1 - n * gn_;
838         return blocks.block(internal::convert_index<int>(grain_index)); // FIXME better make ThreadLocalBlocks use Eigen::Index?
839       } else {
840         return packed_rhs_[k % (P - 1)][n1];
841       }
842     }
843 
844     // In following two methods (pack_lhs and pack_rhs), if we know for sure
845     // that we'll be able to immediately call a kernel with packed data, and do
846     // not submit it to the thread pool, we can use thread local memory for
847     // packed data.
848     //
849     // We can only reliably check it if we are running all kernels in sync mode
850     // (parallelize only by sharding dim). If kernel for m==0 (n==0) is ready to
851     // run, it's guaranteed that all kernels with larger values of m (n) are
852     // also ready, because we execute them in the same order for all K slices.
853 
854     void pack_lhs(Index m, Index k) {
855       bool use_thread_local = false;
856 
857       if (parallelize_by_sharding_dim_only_ && !shard_by_col_ &&
858           can_use_thread_local_packed_[m].load(std::memory_order_relaxed)) {
859         if (state_kernel_[k % P][m][0].load(std::memory_order_relaxed) == 1) {
860           use_thread_local = true;
861         } else {
862           // If we can't guarantee that all kernels in `k` slice will be
863           // executed sequentially in current thread, it's no longer safe to use
864           // thread local memory in following slices along the k dimensions.
865           eigen_assert(k > 0);
866           can_use_thread_local_packed_[m].store(false,
867                                                 std::memory_order_relaxed);
868         }
869       }
870 
871       const Index mend = m * gm_ + gm(m);
872       for (Index m1 = m * gm_; m1 < mend; m1++)
873         kernel_.packLhs(&packed_lhs(m, k, m1, use_thread_local),
874                         lhs_.getSubMapper(m1 * bm_, k * bk_), bk(k), bm(m1));
875 
876       if (!parallel_pack_ && shard_by_col_) {
877         assert(!use_thread_local);
878         signal_packing(k);
879       } else {
880         signal_switch(k + 1);
881         for (Index n = nn_ - 1; n >= 0; n--) {
882           bool sync = parallelize_by_sharding_dim_only_ || n == 0;
883           signal_kernel(m, n, k, sync, use_thread_local);
884         }
885       }
886     }
887 
888     void pack_rhs(Index n, Index k) {
889       bool use_thread_local = false;
890 
891       if (parallelize_by_sharding_dim_only_ && shard_by_col_ &&
892           can_use_thread_local_packed_[n].load(std::memory_order_relaxed)) {
893         if (state_kernel_[k % P][0][n].load(std::memory_order_relaxed) == 1) {
894           use_thread_local = true;
895         } else {
896           // If we can't guarantee that all kernels in `k` slice will be
897           // executed sequentially in current thread, it's no longer safe to use
898           // thread local memory in followig slices along the k dimensions.
899           eigen_assert(k > 0);
900           can_use_thread_local_packed_[n].store(false,
901                                                 std::memory_order_relaxed);
902         }
903       }
904 
905       const Index nend = n * gn_ + gn(n);
906       for (Index n1 = n * gn_; n1 < nend; n1++) {
907         if (!TensorContractionKernel::HasBeta && k == 0) {
908           // Zero the output memory in parallel, only if contraction kernel does
909           // not support `beta`. Otherwise we will pass beta 0.0 to the first
910           // call to the `TensorContractionKernel::invoke()`.
911           //
912           // On 10000x2x10000 mm zeroing can easily take half of time. Zero (bn
913           // x m) row. Safe to do here because all kernels that will write to
914           // this memory depend on completion of this task. Note: don't call
915           // device_.memset() here. device_.memset() blocks on thread pool
916           // worker thread, which can lead to underutilization and deadlocks.
917           memset(buffer_ + n1 * bn_ * m_, 0, bn(n1) * m_ * sizeof(Scalar));
918         }
919         kernel_.packRhs(&packed_rhs(n, k, n1, use_thread_local),
920                         rhs_.getSubMapper(k * bk_, n1 * bn_), bk(k), bn(n1));
921       }
922 
923       if (parallel_pack_ || shard_by_col_) {
924         signal_switch(k + 1);
925         for (Index m = nm_ - 1; m >= 0; m--) {
926           bool sync = parallelize_by_sharding_dim_only_ || m == 0;
927           signal_kernel(m, n, k, sync, use_thread_local);
928         }
929       } else {
930         assert(!use_thread_local);
931         signal_packing(k);
932       }
933     }
934 
935     void kernel(Index m, Index n, Index k, bool use_thread_local) {
936       // Note: order of iteration matters here. Iteration over m is innermost
937       // because we want to reuse the same packed rhs in consecutive tasks
938       // (rhs fits into L2$ while lhs only into L3$).
939       const Index nend = n * gn_ + gn(n);
940       const Index mend = m * gm_ + gm(m);
941 
942       // NOTE: output = alpha * LHS * RHS + beta * output.
943       const Scalar alpha = Scalar(1);
944       const Scalar beta =
945           (TensorContractionKernel::HasBeta && k == 0) ? Scalar(0) : Scalar(1);
946 
947       if (shard_by_col_) {
948         for (Index n1 = n * gn_; n1 < nend; n1++) {
949           for (Index m1 = m * gm_; m1 < mend; m1++) {
950             const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_);
951             kernel_.invoke(
952                 output_mapper,
953                 packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local),
954                 packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1),
955                 bk(k), bn(n1), alpha, beta);
956 
957             // We are done with the last task for the [m1, n1] block.
958             if (k + 1 == nk_) {
959               output_kernel_(output_mapper, tensor_contraction_params_,
960                              m1 * bm_, n1 * bn_, bm(m1), bn(n1));
961             }
962           }
963         }
964       } else {
965         for (Index m1 = m * gm_; m1 < mend; m1++)
966           for (Index n1 = n * gn_; n1 < nend; n1++) {
967             const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_);
968             kernel_.invoke(
969                 output_mapper,
970                 packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local),
971                 packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1),
972                 bk(k), bn(n1), alpha, beta);
973 
974             // We are done with the last task for the [m1, n1] block.
975             if (k + 1 == nk_) {
976               output_kernel_(output_mapper, tensor_contraction_params_,
977                              m1 * bm_, n1 * bn_, bm(m1), bn(n1));
978             }
979           }
980       }
981       signal_kernel(m, n, k + 1, /*sync=*/false, /*use_thread_local=*/false);
982       signal_switch(k + 2);
983     }
984 
985     void signal_packing(Index k) {
986       eigen_assert(!parallel_pack_);
987       Index s = state_packing_ready_[k % P].fetch_sub(1);
988       eigen_assert(s > 0);
989       if (s != 1) return;
990       state_packing_ready_[k % P] = shard_by_col_ ? nm_ : nn_;
991       enqueue_packing(k, shard_by_col_);
992     }
993 
994     void signal_kernel(Index m, Index n, Index k, bool sync,
995                        bool use_thread_local) {
996       std::atomic<uint8_t>* state = &state_kernel_[k % P][m][n];
997       Index s = state->load();
998       eigen_assert(s > 0);
999       if (s != 1 && state->fetch_sub(1) != 1) {
1000         eigen_assert(!use_thread_local);
1001         return;
1002       }
1003       state->store(parallel_pack_ ? 3 : 2, std::memory_order_relaxed);
1004       if (sync) {
1005         kernel(m, n, k, use_thread_local);
1006       } else {
1007         eigen_assert(!use_thread_local);
1008         device_.enqueueNoNotification(
1009             [=]() { kernel(m, n, k, use_thread_local); });
1010       }
1011     }
1012 
1013     void signal_switch(Index k, Index v = 1) {
1014       Index s = state_switch_[k % P].fetch_sub(v);
1015       eigen_assert(s >= v);
1016       if (s != v) return;
1017 
1018       // Ready to switch to the next k slice.
1019       // Reset counter for the next iteration.
1020       state_switch_[k % P] =
1021           (parallel_pack_ ? nm_ + nn_ : (shard_by_col_ ? nn_ : nm_)) +
1022           nm_ * nn_;
1023       if (k < nk_) {
1024         // Issue lhs/rhs packing. Their completion will in turn kick off
1025         // kernels.
1026         if (parallel_pack_) {
1027           enqueue_packing(k, !shard_by_col_);
1028           enqueue_packing(k, shard_by_col_);
1029         } else if (shard_by_col_) {
1030           enqueue_packing(k, false);
1031         } else {
1032           enqueue_packing(k, true);
1033         }
1034 
1035         // Termination handling.
1036         // Because kernel completion signals k + 2 switch, we need to finish nk
1037         // + 2 slices without issuing any tasks on nk + 1 slice. So here we
1038         // pretend that all nk + 1 packing tasks just finish instantly; so that
1039         // nk + 2 switch only waits for completion of nk kernels.
1040       } else if (k == nk_) {
1041         signal_switch(k + 1,
1042                       parallel_pack_ ? nm_ + nn_ : (shard_by_col_ ? nn_ : nm_));
1043       } else {
1044         done_.Notify();
1045       }
1046     }
1047 
1048     // Enqueue all rhs/lhs packing for k-th slice.
1049     void enqueue_packing(Index k, bool rhs) {
1050       enqueue_packing_helper(0, rhs ? nn_ : nm_, k, rhs);
1051     }
1052 
1053     void enqueue_packing_helper(Index start, Index end, Index k, bool rhs) {
1054       if (end - start == 1) {
1055         if (rhs)
1056           pack_rhs(start, k);
1057         else
1058           pack_lhs(start, k);
1059       } else {
1060         while (end - start > 1) {
1061           Index mid = (start + end) / 2;
1062           device_.enqueueNoNotification(
1063               [=]() { enqueue_packing_helper(mid, end, k, rhs); });
1064           end = mid;
1065         }
1066 
1067         // Decide if we want to run first packing task (start == 0) in
1068         // async mode if we parallelize only by sharding dim:
1069         // (1) pack_lhs and pack_rhs call signal_switch before completing
1070         //     all calls to signal_kernel, which in sync mode might lead
1071         //     to the execution of the first kernel of the k+1 slice, before
1072         //     completing a call to the last kernel of the k slice.
1073         // (2) all pack tasks for sharded dim must be executed in a thread
1074         //     pool to get pre-allocated thead local buffers.
1075         bool pack_async =
1076           (start == 0) &&
1077           (parallelize_by_sharding_dim_only_&& shard_by_col_ == rhs) &&
1078           (k > 0 || std::this_thread::get_id() == created_by_thread_id_);
1079 
1080         if (pack_async) {
1081           device_.enqueueNoNotification(
1082               [=]() { enqueue_packing_helper(start, end, k, rhs); });
1083         } else {
1084           enqueue_packing_helper(start, end, k, rhs);
1085         }
1086       }
1087     }
1088 
1089     // Block sizes with accounting for potentially incomplete last block.
1090     Index bm(Index m) const { return m + 1 < nm0_ ? bm_ : m_ + bm_ - bm_ * nm0_; }
1091     Index bn(Index n) const { return n + 1 < nn0_ ? bn_ : n_ + bn_ - bn_ * nn0_; }
1092     Index bk(Index k) const { return k + 1 < nk_ ? bk_ : k_ + bk_ - bk_ * nk_; }
1093     // Task grain sizes accounting for potentially incomplete last task.
1094     Index gm(Index m) const { return m + 1 < nm_ ? gm_ : nm0_ + gm_ - gm_ * nm_; }
1095     Index gn(Index n) const { return n + 1 < nn_ ? gn_ : nn0_ + gn_ - gn_ * nn_; }
1096 
1097     EvalParallelContext(const EvalParallelContext&) = delete;
1098     void operator=(const EvalParallelContext&) = delete;
1099   };
1100 
1101   template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
1102             bool rhs_inner_dim_reordered, int Alignment>
1103   using SyncEvalParallelContext =
1104       EvalParallelContext<NoCallback, lhs_inner_dim_contiguous,
1105                           rhs_inner_dim_contiguous, rhs_inner_dim_reordered,
1106                           Alignment>;
1107 
1108   // ------------------------------------------------------------------------ //
1109 
1110   // EvalShardedByInnerDimContext orchestrates sync/async contraction
1111   // evaluation, when we shard by inner dimension. When it is executed in
1112   // asynchronous mode, it owns all the shared state that might be accessible by
1113   // block processing tasks.
1114 
1115   template <typename DoneCallback>
1116   struct EvalShardedByInnerDimContext {
1117     EvalShardedByInnerDimContext(const Self* self, int num_threads,
1118                                  Scalar* result_buffer,
1119                                  Index m_size, Index n_size, Index k_size,
1120                                  DoneCallback done_callback)
1121         : evaluator(self),
1122           m_lhs_inner_dim_contiguous(evaluator->m_lhs_inner_dim_contiguous),
1123           m_rhs_inner_dim_contiguous(evaluator->m_rhs_inner_dim_contiguous),
1124           m_rhs_inner_dim_reordered(evaluator->m_rhs_inner_dim_reordered),
1125           result(result_buffer),
1126           m(m_size),
1127           n(n_size),
1128           k(k_size),
1129           done(std::move(done_callback)),
1130           buffer_size_bytes(m * n * sizeof(Scalar)),
1131           block_size(blockSize(k, num_threads)),
1132           num_blocks(divup<Index>(k, block_size)),
1133           num_pending_blocks(internal::convert_index<int>(num_blocks)),
1134           l0_ranges(divup<Index>(num_blocks, l0_size)),
1135           l0_state(l0_ranges),
1136           block_buffers(num_blocks) {
1137       // Keep count of pending gemm tasks for each l0 range.
1138       for (int i = 0; i < l0_ranges; ++i) {
1139         const Index num_pending_tasks = actualRangeSize(l0_ranges, l0_size, i);
1140         l0_state.emplace_back(internal::convert_index<int>(num_pending_tasks));
1141       }
1142 
1143       // Allocate temporary buffers for each block.
1144       for (Index block_idx = 0; block_idx < num_blocks; ++block_idx) {
1145         Scalar* buf = block_idx == 0
1146                           ? result
1147                           : static_cast<Scalar*>(evaluator->m_device.allocate(
1148                                 buffer_size_bytes));
1149         block_buffers.emplace_back(buf);
1150       }
1151     }
1152 
1153     ~EvalShardedByInnerDimContext() {
1154       for (Index i = 1; i < num_blocks; ++i) {
1155         evaluator->m_device.deallocate(block_buffers[i]);
1156       }
1157     }
1158 
1159     template <int Alignment>
1160     void run() {
1161       Barrier barrier(internal::convert_index<int>(num_blocks));
1162       eval<Alignment>(barrier, 0, num_blocks);
1163       barrier.Wait();
1164 
1165       // Aggregate partial sums from l0 ranges.
1166       aggregateL0Blocks<Alignment>();
1167 
1168       // Apply output kernel.
1169       applyOutputKernel();
1170     }
1171 
1172     template <int Alignment>
1173     void runAsync() {
1174       evalAsync<Alignment>(0, num_blocks);
1175     }
1176 
1177    private:
1178     // The underlying GEMM kernel assumes that k is a multiple of
1179     // the packet size and subtle breakage occurs if this is violated.
1180     static const Index packet_size = internal::packet_traits<RhsScalar>::size;
1181 
1182     const Self* evaluator;  // TensorContraction evaluator
1183 
1184     // These fields required fromTENSOR_CONTRACTION_DISPATCH macro.
1185     bool m_lhs_inner_dim_contiguous;
1186     bool m_rhs_inner_dim_contiguous;
1187     bool m_rhs_inner_dim_reordered;
1188 
1189     Scalar* result;
1190 
1191     Index m;
1192     Index n;
1193     Index k;
1194 
1195     DoneCallback done;
1196 
1197     // ----------------------------------------------------------------------//
1198     // Algorithm parameters.
1199 
1200     // We will compute partial results into the buffers of this size.
1201     Index buffer_size_bytes;
1202 
1203     Index block_size;
1204     Index num_blocks;
1205 
1206     // Keep track of pending tasks when evaluate in async mode.
1207     std::atomic<int> num_pending_blocks;
1208 
1209     // We compute partial gemm results in parallel, and to get the final result
1210     // we need to add them all together. For the large number of threads (>= 48)
1211     // this adds a very expensive sequential step at the end.
1212     //
1213     // We split the [0, num_blocks) into small ranges, and when a task for the
1214     // block finishes its partial gemm computation, it checks if it was the last
1215     // gemm in the range, and if so, it will add all blocks of the range.
1216     //
1217     // After all tasks done, we need to add only these pre-aggregated blocks.
1218 
1219     // For now we use just a single level of ranges to compute pre-aggregated
1220     // partial sums, but in general we can use more layers to compute tree
1221     // aggregation in parallel and reduce the size of the sequential step.
1222     //
1223     // TODO(ezhulenev): Add multilevel tree aggregation? Probably will make
1224     // sense only if number of threads >= ~128?
1225     static const Index l0_size = 4;
1226     Index l0_ranges;
1227 
1228     // Keep count of pending gemm tasks for each l0 range.
1229     MaxSizeVector<std::atomic<int>> l0_state;  // [0, l0_ranges)
1230 
1231     // Buffers allocated for each temporary block computation.
1232     MaxSizeVector<Scalar*> block_buffers;  // [0, num_blocks)
1233 
1234     template <int Alignment>
1235     void processBlock(Index block_idx, Index begin, Index end) {
1236       Scalar* buf = block_buffers[block_idx];
1237 
1238       TENSOR_CONTRACTION_DISPATCH(
1239           evaluator->template evalGemmPartialWithoutOutputKernel, Alignment,
1240           (buf, begin, end,
1241            /*num_threads=*/internal::convert_index<int>(num_blocks)));
1242 
1243       // Check if it was the last task in l0 range.
1244       const Index l0_index = block_idx / l0_size;
1245       const int v = l0_state[l0_index].fetch_sub(1);
1246       eigen_assert(v >= 1);
1247 
1248       // If we processed the last block of the range, we can aggregate all
1249       // partial results into the first block of the range.
1250       if (v == 1) {
1251         const Index rng_size = actualRangeSize(l0_ranges, l0_size, l0_index);
1252         const Index dst_block_idx = l0_index * l0_size;
1253 
1254         if (rng_size == l0_size) {
1255           addAllToBuffer<Alignment>(
1256               m * n,
1257               /*src_buf0=*/block_buffers[dst_block_idx + 1],
1258               /*src_buf1=*/block_buffers[dst_block_idx + 2],
1259               /*src_buf2=*/block_buffers[dst_block_idx + 3],
1260               /*dst_buf= */ block_buffers[dst_block_idx]);
1261         } else {
1262           // Aggregate blocks of potentially incomplete last range.
1263           for (int i = 1; i < rng_size; ++i) {
1264             addToBuffer<Alignment>(m * n,
1265                                    /*src_buf=*/block_buffers[dst_block_idx + i],
1266                                    /*dst_buf=*/block_buffers[dst_block_idx]);
1267           }
1268         }
1269       }
1270     }
1271 
1272     // Aggregate partial sums from l0 ranges.
1273     template <int Alignment>
1274     void aggregateL0Blocks() const {
1275       Index l0_index = 1;
1276 
1277       for (; l0_index + 2 < l0_ranges; l0_index += 3) {
1278         addAllToBuffer<Alignment>(
1279             m * n,
1280             /*src_buf0=*/block_buffers[(l0_index + 0) * l0_size],
1281             /*src_buf1=*/block_buffers[(l0_index + 1) * l0_size],
1282             /*src_buf2=*/block_buffers[(l0_index + 2) * l0_size],
1283             /*dst_buf= */ block_buffers[0]);
1284       }
1285 
1286       for (; l0_index < l0_ranges; ++l0_index) {
1287         addToBuffer<Alignment>(m * n, block_buffers[l0_index * l0_size],
1288                                block_buffers[0]);
1289       }
1290     }
1291 
1292     void applyOutputKernel() const {
1293       typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
1294       evaluator->m_output_kernel(
1295           OutputMapper(result, m), evaluator->m_tensor_contraction_params,
1296           static_cast<Eigen::Index>(0), static_cast<Eigen::Index>(0), m, n);
1297     }
1298 
1299     // Compute block size with accounting for potentially incomplete last block.
1300     Index actualBlockSize(Index block_idx) const {
1301       return block_idx + 1 < num_blocks
1302                  ? block_size
1303                  : k + block_size - block_size * num_blocks;
1304     };
1305 
1306     // Compute range size with accounting for potentially incomplete last range.
1307     Index actualRangeSize(Index num_ranges, Index range_size,
1308                           Index range_idx) const {
1309       eigen_assert(range_idx < num_ranges);
1310       return range_idx + 1 < num_ranges
1311                  ? range_size
1312                  : num_blocks + range_size - range_size * num_ranges;
1313     };
1314 
1315     template <int Alignment>
1316     EIGEN_STRONG_INLINE static void addToBuffer(size_t n, const Scalar* src_buf,
1317                                                 Scalar* tgt_buf) {
1318       const int output_packet_size =
1319           internal::unpacket_traits<PacketReturnType>::size;
1320       size_t i = 0;
1321       const size_t num_packets = n / output_packet_size;
1322       for (; i < output_packet_size * num_packets; i += output_packet_size) {
1323         const PacketReturnType src_val =
1324             internal::pload<PacketReturnType>(src_buf + i);
1325         const PacketReturnType tgt_val =
1326             internal::ploadt<PacketReturnType, Alignment>(tgt_buf + i);
1327         const PacketReturnType sum = internal::padd(src_val, tgt_val);
1328         internal::pstoret<Scalar, PacketReturnType, Alignment>(tgt_buf + i,
1329                                                                sum);
1330       }
1331       for (; i < n; ++i) {
1332         tgt_buf[i] += src_buf[i];
1333       }
1334     }
1335 
1336     template <int Alignment>
1337     EIGEN_STRONG_INLINE static void addAllToBuffer(size_t n,
1338                                                    const Scalar* src_buf0,
1339                                                    const Scalar* src_buf1,
1340                                                    const Scalar* src_buf2,
1341                                                    Scalar* dst_buf) {
1342       using ::Eigen::internal::padd;
1343       using ::Eigen::internal::pload;
1344       using ::Eigen::internal::ploadt;
1345       using ::Eigen::internal::pstoret;
1346 
1347       const int output_packet_size =
1348           internal::unpacket_traits<PacketReturnType>::size;
1349 
1350       size_t i = 0;
1351       const size_t num_packets = n / output_packet_size;
1352       for (; i < output_packet_size * num_packets; i += output_packet_size) {
1353         const auto src_val0 = pload<PacketReturnType>(src_buf0 + i);
1354         const auto src_val1 = pload<PacketReturnType>(src_buf1 + i);
1355         const auto src_val2 = pload<PacketReturnType>(src_buf2 + i);
1356 
1357         const auto dst_val = ploadt<PacketReturnType, Alignment>(dst_buf + i);
1358         const auto sum =
1359             padd(padd(dst_val, src_val0), padd(src_val1, src_val2));
1360 
1361         pstoret<Scalar, PacketReturnType, Alignment>(dst_buf + i, sum);
1362       }
1363       for (; i < n; ++i) {
1364         dst_buf[i] += src_buf0[i] + src_buf1[i] + src_buf2[i];
1365       }
1366     }
1367 
1368     template <int Alignment>
1369     void eval(Barrier& barrier, Index start_block_idx, Index end_block_idx) {
1370       while (end_block_idx - start_block_idx > 1) {
1371         Index mid_block_idx = (start_block_idx + end_block_idx) / 2;
1372         evaluator->m_device.enqueueNoNotification(
1373             [this, &barrier, mid_block_idx, end_block_idx]() {
1374               eval<Alignment>(barrier, mid_block_idx, end_block_idx);
1375             });
1376         end_block_idx = mid_block_idx;
1377       }
1378 
1379       Index block_idx = start_block_idx;
1380       Index block_start = block_idx * block_size;
1381       Index block_end = block_start + actualBlockSize(block_idx);
1382 
1383       processBlock<Alignment>(block_idx, block_start, block_end);
1384       barrier.Notify();
1385     }
1386 
1387     template <int Alignment>
1388     void evalAsync(Index start_block_idx, Index end_block_idx) {
1389       while (end_block_idx - start_block_idx > 1) {
1390         Index mid_block_idx = (start_block_idx + end_block_idx) / 2;
1391         evaluator->m_device.enqueueNoNotification(
1392             [this, mid_block_idx, end_block_idx]() {
1393               evalAsync<Alignment>(mid_block_idx, end_block_idx);
1394             });
1395         end_block_idx = mid_block_idx;
1396       }
1397 
1398       Index block_idx = start_block_idx;
1399 
1400       Index block_start = block_idx * block_size;
1401       Index block_end = block_start + actualBlockSize(block_idx);
1402 
1403       processBlock<Alignment>(block_idx, block_start, block_end);
1404 
1405       int v = num_pending_blocks.fetch_sub(1);
1406       eigen_assert(v >= 1);
1407 
1408       if (v == 1) {
1409         // Aggregate partial sums from l0 ranges.
1410         aggregateL0Blocks<Alignment>();
1411 
1412         // Apply output kernel.
1413         applyOutputKernel();
1414 
1415         // NOTE: If we call `done` callback before deleting this (context),
1416         // it might deallocate Self* pointer captured by context, and we'll
1417         // fail in destructor trying to deallocate temporary buffers.
1418 
1419         // Move done call back from context before it will be destructed.
1420         DoneCallback done_copy = std::move(done);
1421 
1422         // We are confident that we are the last one who touches context.
1423         delete this;
1424 
1425         // Now safely call the done callback.
1426         done_copy();
1427       }
1428     }
1429 
1430     // Cost model doesn't capture well the cost associated with constructing
1431     // tensor contraction mappers and computing loop bounds in gemm_pack_lhs
1432     // and gemm_pack_rhs, so we specify minimum desired block size.
1433     static Index blockSize(Index k, int num_threads) {
1434       const auto round_up = [=](Index index) -> Index {
1435         const Index kmultiple = packet_size <= 8 ? 8 : packet_size;
1436         return divup<Index>(index, kmultiple) * kmultiple;
1437       };
1438 
1439       const Index target_block_size = round_up(divup<Index>(k, num_threads));
1440       const Index desired_min_block_size = 12 * packet_size;
1441 
1442       return numext::mini<Index>(
1443           k, numext::maxi<Index>(desired_min_block_size, target_block_size));
1444     }
1445 
1446     EvalShardedByInnerDimContext(const EvalShardedByInnerDimContext&) = delete;
1447     void operator=(const EvalShardedByInnerDimContext&) = delete;
1448   };
1449 
1450   // ------------------------------------------------------------------------ //
1451 
1452   // Below are the function used by evalProductImpl heuristics, trying to select
1453   // optimcal parameters for parallelization algorithm.
1454 
1455   // Decide whether we want to shard m x n contraction by columns or by rows.
1456   static bool shardByCol(Index m, Index n, Index num_threads) {
1457     // Note: we are comparing both n and m against Traits::nr, it is not
1458     // a mistake. We are trying to figure out how both n and m will fit into
1459     // the main sharding dimension.
1460 
1461     // Sharding by column is the default
1462     // ... unless there is enough data for vectorization over rows
1463     if (m / num_threads >= Traits::nr &&
1464         // and not enough data for vectorization over columns
1465         (n / num_threads < Traits::nr ||
1466          // ... or barely enough data for vectorization over columns,
1467          // but it is not evenly dividable across threads
1468          (n / num_threads < 4 * Traits::nr &&
1469           (n % (num_threads * Traits::nr)) != 0 &&
1470           // ... and it is evenly dividable across threads for rows
1471           ((m % (num_threads * Traits::nr)) == 0 ||
1472            // .. or it is not evenly dividable for both dimensions but
1473            // there is much more data over rows so that corner effects are
1474            // mitigated.
1475            (m / n >= 6)))))
1476       return false;
1477     // Wait, or if matrices are just substantially prolonged over the other
1478     // dimension.
1479     if (n / num_threads < 16 * Traits::nr && m > n * 32) return false;
1480     return true;
1481   }
1482 
1483   Index coarsenM(Index m, Index n, Index bm, Index bn, Index bk, Index gn,
1484                  int num_threads, bool shard_by_col) const {
1485     Index gm = 1;
1486     Index gm1 = 1;
1487     Index nm0 = divup(m, bm);
1488     Index nm1 = nm0;
1489     for (;;) {
1490       // Find the next candidate for m grain size. It needs to result in
1491       // different number of blocks. E.g. if we have 10 kernels, we want to try
1492       // 5 and 10, but not 6, 7, 8 and 9.
1493       while (gm1 <= nm0 && nm1 == divup(nm0, gm1)) gm1++;
1494       if (gm1 > nm0) break;
1495       // Check the candidate.
1496       int res = checkGrain(m, n, bm, bn, bk, gm1, gn, gm, gn, num_threads,
1497                            shard_by_col);
1498       if (res < 0) break;
1499       nm1 = divup(nm0, gm1);
1500       if (res == 0) continue;
1501       // Commit new grain size.
1502       gm = gm1;
1503     }
1504     return gm;
1505   }
1506 
1507   Index coarsenN(Index m, Index n, Index bm, Index bn, Index bk, Index gm,
1508                  int num_threads, bool shard_by_col) const {
1509     Index gn = 1;
1510     Index gn1 = 1;
1511     Index nn0 = divup(n, bn);
1512     Index nn1 = nn0;
1513     for (;;) {
1514       while (gn1 <= nn0 && nn1 == divup(nn0, gn1)) gn1++;
1515       if (gn1 > nn0) break;
1516       int res = checkGrain(m, n, bm, bn, bk, gm, gn1, gm, gn, num_threads,
1517                            shard_by_col);
1518       if (res < 0) break;
1519       nn1 = divup(nn0, gn1);
1520       if (res == 0) continue;
1521       gn = gn1;
1522     }
1523     return gn;
1524   }
1525 
1526   // checkGrain checks whether grain (gm, gn) is suitable and is better than
1527   // (oldgm, oldgn).
1528   int checkGrain(Index m, Index n, Index bm, Index bn, Index bk, Index gm,
1529                  Index gn, Index oldgm, Index oldgn, int num_threads,
1530                  bool shard_by_col) const {
1531     const TensorOpCost cost =
1532         contractionCost(bm * gm, bn * gn, bm, bn, bk, shard_by_col, true);
1533     double taskSize = TensorCostModel<ThreadPoolDevice>::taskSize(
1534         static_cast<double>(bm) * gm * bn * gn, cost);
1535     // If the task is too small, then we agree on it regardless of anything
1536     // else. Otherwise synchronization overheads will dominate.
1537     if (taskSize < 1) return 1;
1538     // If it is too large, then we reject it and all larger tasks.
1539     if (taskSize > 2) return -1;
1540     // Now we are in presumably good task size range.
1541     // The main deciding factor here is parallelism. Consider that we have 12
1542     // kernels and 4 threads. Grains of 2, 3 and 4 all yield good task sizes.
1543     // But 2/4 yield 6/3 tasks, which gives us parallelism of 0.75 (at most 3/4
1544     // of cores will be busy). While grain size 3 gives us 4 tasks, which gives
1545     // us parallelism of 1 (we can load all cores).
1546     Index nm0 = divup(m, bm);
1547     Index nn0 = divup(n, bn);
1548     Index new_tasks = divup(nm0, gm) * divup(nn0, gn);
1549     double new_parallelism = static_cast<double>(new_tasks) /
1550                              (divup<int>(new_tasks, num_threads) * num_threads);
1551     Index old_tasks = divup(nm0, oldgm) * divup(nn0, oldgn);
1552     double old_parallelism = static_cast<double>(old_tasks) /
1553                              (divup<int>(old_tasks, num_threads) * num_threads);
1554     if (new_parallelism > old_parallelism || new_parallelism == 1) return 1;
1555     return 0;
1556   }
1557 
1558   TensorOpCost contractionCost(Index m, Index n, Index bm, Index bn, Index bk,
1559                                bool shard_by_col, bool prepacked) const {
1560     const int packed_size = std::min<int>(PacketType<LhsScalar, Device>::size,
1561                                           PacketType<RhsScalar, Device>::size);
1562     const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
1563     const double kd = static_cast<double>(bk);
1564     double compute_bandwidth = computeBandwidth(false, bm, bn, bk);
1565     // Computations.
1566     TensorOpCost cost = TensorOpCost(0, 0, kd * compute_bandwidth, true, packed_size);
1567     // Output stores.
1568     cost += TensorOpCost(0, sizeof(CoeffReturnType), 0, true, output_packet_size);
1569     if (prepacked) {
1570       // Packing and kernels are executed in different tasks. When we calculate
1571       // task grain size we look only at kernel cost assuming that kernel
1572       // is more expensive than packing.
1573       return cost;
1574     }
1575     // Lhs/rhs loads + computations.
1576     TensorOpCost lhsCost = this->m_leftImpl.costPerCoeff(true) * (kd / n);
1577     TensorOpCost rhsCost = this->m_rightImpl.costPerCoeff(true) * (kd / m);
1578     // Lhs packing memory cost does not contribute considerably to overall
1579     // execution time because lhs is prefetched early and accessed sequentially.
1580     if (shard_by_col)
1581       lhsCost.dropMemoryCost();
1582     else
1583       rhsCost.dropMemoryCost();
1584     return cost + lhsCost + rhsCost;
1585   }
1586 
1587   // Decide whether we want to shard m x k x n contraction over the inner
1588   // (contraction) dimension (k).
1589   static bool shardByInnerDim(Index m, Index n, Index k, int num_threads,
1590                               int num_threads_by_k) {
1591     std::ptrdiff_t bufsize = m * n * sizeof(Scalar);
1592     bool shard_by_k = false;
1593     if (n == 1 ||                // If mat*vec or...
1594         num_threads_by_k < 2 ||  // running single threaded or...
1595         num_threads_by_k <
1596             num_threads ||  // sharding by k gives less parallelism or...
1597         bufsize > l3CacheSize() / num_threads_by_k ||  // need more buffer space
1598         // than L3 cache or...
1599         k / num_threads_by_k < 2 * Traits::nr) {  // k per thread is tiny.
1600       shard_by_k = false;
1601     } else if (numext::maxi(m, n) / num_threads <
1602                    Traits::nr ||  // both other dimensions are tiny or...
1603                // k per thread is not small and...
1604                (k / num_threads_by_k > 8 * Traits::nr &&
1605                 // one of the outer dimensions is tiny or sharding by k offers
1606                 // more parallelism.
1607                 (numext::mini(m, n) < 2 * Traits::nr ||
1608                  num_threads_by_k > num_threads))) {
1609       shard_by_k = true;
1610     }
1611     return shard_by_k;
1612   }
1613 
1614   TensorOpCost contractionCostPerInnerDim(Index m, Index n, Index k) const {
1615     // Compute cost.
1616     const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
1617     TensorOpCost cost(0, 0, (computeBandwidth(true, m, n, k) * m) * n, true, output_packet_size);
1618     // Output stores.
1619     cost += TensorOpCost(0, sizeof(CoeffReturnType), 0, true, output_packet_size);
1620     TensorOpCost lhsCost = this->m_leftImpl.costPerCoeff(true) * m;
1621     TensorOpCost rhsCost = this->m_rightImpl.costPerCoeff(true) * n;
1622     // Since the inner gemm kernel is always sharded by column, the lhs
1623     // load cost is negligible.
1624     lhsCost.dropMemoryCost();
1625     return cost + lhsCost + rhsCost;
1626   }
1627 
1628   int numThreadsInnerDim(Index m, Index n, Index k) const {
1629     const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
1630     TensorOpCost cost = contractionCostPerInnerDim(m, n, k);
1631     double total_parallel_cost =
1632         TensorCostModel<ThreadPoolDevice>::totalCost(k, cost);
1633     // Cost of reduction step accumulating the m*n per-thread buffers into the
1634     // result.
1635     double reduction_cost = TensorCostModel<ThreadPoolDevice>::totalCost(
1636         m * n, TensorOpCost(2, 1, 1, true, output_packet_size));
1637     int num_threads = 1;
1638     double min_cost = total_parallel_cost;
1639     double kPerThreadOverHead = 3000;
1640     double kFixedOverHead = 100000;
1641     for (int nt = 2; nt <= this->m_device.numThreads(); nt += 2) {
1642       double sequential_cost =
1643           kFixedOverHead + nt * (reduction_cost + kPerThreadOverHead);
1644       double parallel_cost = total_parallel_cost / nt + sequential_cost;
1645       if (parallel_cost < min_cost) {
1646         num_threads = nt;
1647         min_cost = parallel_cost;
1648       }
1649     }
1650     return num_threads;
1651   }
1652 
1653   double computeBandwidth(bool shard_by_col, Index bm, Index bn,
1654                           Index bk) const {
1655     // Peak VFMA bandwidth is 0.5. However if we have not enough data for
1656     // vectorization bandwidth drops. The 4.0 and 2.0 bandwidth is determined
1657     // experimentally.
1658     double computeBandwidth =
1659         bk == 1 ? 4.0
1660                 : (shard_by_col ? bn : bm) < Traits::nr ||
1661                           (shard_by_col ? bm : bn) < Traits::mr
1662                       ? 2.0
1663                       : 0.5;
1664 #ifndef EIGEN_VECTORIZE_FMA
1665     // Bandwidth of all of VFMA/MULPS/ADDPS is 0.5 on latest Intel processors.
1666     // However for MULPS/ADDPS we have dependent sequence of 2 such
1667     // instructions,
1668     // so overall bandwidth is 1.0.
1669     if (computeBandwidth == 0.5) computeBandwidth = 1.0;
1670 #endif
1671     return computeBandwidth;
1672   }
1673 
1674 };
1675 
1676 } // end namespace Eigen
1677 
1678 #endif  // EIGEN_USE_THREADS
1679 #endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H
1680