xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/sparse_matmul_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/math_ops.cc.
17 
18 #define EIGEN_USE_THREADS
19 
20 #include "tensorflow/core/kernels/sparse_matmul_op.h"
21 
22 #include <map>
23 #include <memory>
24 #include <vector>
25 
26 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
27 #include "tensorflow/core/common_runtime/device.h"
28 #include "tensorflow/core/framework/bfloat16.h"
29 #include "tensorflow/core/framework/op.h"
30 #include "tensorflow/core/framework/op_kernel.h"
31 #include "tensorflow/core/framework/types.h"
32 #include "tensorflow/core/kernels/fill_functor.h"
33 #include "tensorflow/core/lib/core/blocking_counter.h"
34 #include "tensorflow/core/lib/core/threadpool.h"
35 #include "tensorflow/core/platform/errors.h"
36 #include "tensorflow/core/platform/logging.h"
37 #include "tensorflow/core/platform/macros.h"
38 #include "tensorflow/core/platform/mutex.h"
39 #include "tensorflow/core/platform/thread_annotations.h"
40 #include "tensorflow/core/platform/types.h"
41 #ifdef TENSORFLOW_USE_LIBXSMM
42 #include "include/libxsmm_intrinsics_x86.h"
43 #include "include/libxsmm_malloc.h"
44 #include "include/libxsmm_spmdm.h"
45 #endif
46 
47 #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
48 #include "tensorflow/core/kernels/eigen_contraction_kernel.h"
49 #endif
50 
51 #define ALWAYS_INLINE EIGEN_ALWAYS_INLINE
52 
53 namespace tensorflow {
54 namespace {
55 
56 template <typename T>
57 using BasicMatrix = Eigen::Tensor<T, 2, Eigen::RowMajor>;
58 
59 template <typename T>
60 using BasicMatrixMap =
61     Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Aligned>;
62 
63 using Matrix = BasicMatrix<float>;
64 using MatrixMap = BasicMatrixMap<float>;
65 using CPUDevice = Eigen::ThreadPoolDevice;
66 using DSizes = Eigen::DSizes<Eigen::DenseIndex, 2>;
67 
68 // Two commonly used static dsizes. We use Eigen::type2index to allow as much
69 // compile time optimization as possible.
70 inline Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<0>>
71 dsizes_00() {
72   return Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<0>>();
73 }
74 inline Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<0>>
75 dsizes_10() {
76   return Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<0>>();
77 }
78 
79 // Blocksizes
80 // TODO(agarwal): compute these sizes based on cache sizes.
81 const int K = 64;
82 const int M = 64;
83 const int N = 128;
84 
85 // This stores a sparse representation of a slice of a matrix with size
86 // (num_rows, num_cols). The slice is represented as a series of blocks of size
87 // (num_rows, b), where b = block_size for all but the last block, which may
88 // have fewer columns.
89 //
90 // num_rows and block_size are assumed to be <= 256. This allows storing
91 // different indices as uint8.
92 //
93 // For each block, we store all the non-zero entries in the data/data3 vectors
94 // and the corresponding coordinates of each element in the index/index3
95 // vectors. The index3 vector stores the indices of 3 elements in the same row
96 // so that these elements can share the same row coordinate. Each entry in
97 // Index3 corresponds to 3 entries in data3.
98 //
99 // Note that all the data/indices of all the blocks are stored in the same
100 // vectors respectively. To identify block boundaries, we store the block
101 // offsets using index3_offset/index_offset. If there are n blocks in the slice,
102 // index3_offset and index_offset have n entries. The indices for the ith block
103 // are the entries of index3 at positions
104 // [index3_offset[i-1], index3_offset[i]). Similarly for index and
105 // index_offset.
106 template <typename T>
107 struct SparseSlice {
108   using ConstMatrixMap = BasicMatrixMap<const T>;
109 
110  public:
111   // Indices of three elements on the same row.
112   struct Index3 {
113     uint8 m;  // row
114     // columns
115     uint8 k1;
116     uint8 k2;
117     uint8 k3;
118   };
119 
120   // Index of one element.
121   struct Index {
122     uint8 m;
123     uint8 k;
124   };
125 
126   SparseSlice(int nrows, int ncols, int bsize)
127       : num_rows(nrows), num_cols(ncols), block_size(bsize) {
128     DCHECK_LE(nrows, 256);
129     DCHECK_LE(block_size, 256);
130   }
131 
132   // Initializes the slice with data starting at mat(0, col_offset) and with
133   // size (num_rows, num_cols).
134   // If Transpose is true, implicitly transposes mat.
135   template <bool Transpose = false>
136   void Initialize(const ConstMatrixMap& mat, int col_offset);
137 
138   void Clear();
139 
140   // See comments above.
141   std::vector<int> index3_offset;
142   std::vector<Index3> index3;
143   std::vector<T> data3;
144 
145   // See comments above. Similar to "index3" except that each element in "index"
146   // corresponds to one element in data.
147   std::vector<int> index_offset;
148   std::vector<Index> index;
149   std::vector<T> data;
150 
151   // Number of rows and columns for the slice.
152   const int num_rows;
153   const int num_cols;
154 
155   // Block size used to initialize from a matrix.
156   const int block_size;
157 };
158 
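// Editor's note: a minimal worked example of the encoding above (hypothetical
// values, not part of the original file). For an 8x8 block whose only
// non-zeros are row 2 = {1.0 at col 1, 2.0 at col 4, 3.0 at col 7} and
// row 5 = {4.0 at col 0}, Initialize() would produce:
//
//   index3 == {{m:2, k1:1, k2:4, k3:7}}   data3 == {1.0, 2.0, 3.0}
//   index  == {{m:5, k:0}}                data  == {4.0}
//   index3_offset == {1}                  index_offset == {1}
//
// i.e. complete triples of non-zeros within a row go to index3/data3, and the
// remaining one or two non-zeros of that row go to index/data.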
159 template <typename T>
160 bool IsZero(T v);
161 
162 template <>
163 ALWAYS_INLINE bool IsZero(bfloat16 v) {
164   return !static_cast<bool>(v);
165 }
166 
167 template <>
168 ALWAYS_INLINE bool IsZero(float v) {
169   return v == 0.0f;
170 }
171 
172 template <typename T>
173 template <bool Transpose>
174 void SparseSlice<T>::Initialize(
175     const typename SparseSlice<T>::ConstMatrixMap& mat, int col_offset) {
176   const int mat_rows = Transpose ? mat.dimension(1) : mat.dimension(0);
177   const int mat_cols = Transpose ? mat.dimension(0) : mat.dimension(1);
178   DCHECK_LE(num_rows, mat_rows);
179   DCHECK_LE(num_cols + col_offset, mat_cols);
180 
181   int num_blocks = (num_cols + block_size - 1) / block_size;
182   int mat_size = num_rows * num_cols;
183 
184   index3_offset.reserve(num_blocks);
185   data3.reserve(mat_size);
186   index3.reserve(mat_size / 3);
187 
188   index_offset.reserve(num_blocks);
189   data.reserve(num_blocks * num_rows * 2);
190   index.reserve(num_blocks * num_rows * 2);
191 
192   Index3 idx3;
193   const int stride = Transpose ? mat.dimension(1) : 1;
194 
195   for (int i = 0; i < num_blocks; ++i) {
196     int num_block_cols = std::min(block_size, num_cols - block_size * i);
197     for (int row = 0; row < num_rows; ++row) {
198       idx3.m = static_cast<uint8>(row);
199       // Safety note: The following code has a race, since it checks whether
200       // *curr is nonzero and then reads it again on use.  However, the result
201       // of the race is only that some of the "nonzeros" in the resulting sparse
202       // representation may actually be zero, which is harmless.
203       const auto* start =
204           Transpose ? &mat(col_offset, row) : &mat(row, col_offset);
205       const auto* curr = start;
206       const auto* end = start + stride * num_block_cols;
207       uint8 k = 0;
208 #define NEXT_ELEM \
209   curr += stride; \
210   ++k;
211 #define EAT_ZEROS                          \
212   while (curr < end && IsZero<T>(*curr)) { \
213     NEXT_ELEM;                             \
214   }
215       while (true) {
216         EAT_ZEROS
217         if (curr >= end) break;
218         idx3.k1 = k;
219         const T value1 = *curr;
220         NEXT_ELEM;
221 
222         EAT_ZEROS
223         if (curr >= end) {
224           data.push_back(value1);
225           index.push_back({idx3.m, idx3.k1});
226           break;
227         }
228         idx3.k2 = k;
229         const T value2 = *curr;
230         NEXT_ELEM;
231 
232         EAT_ZEROS
233         if (curr >= end) {
234           data.push_back(value2);
235           index.push_back({idx3.m, idx3.k2});
236           data.push_back(value1);
237           index.push_back({idx3.m, idx3.k1});
238           break;
239         }
240         idx3.k3 = k;
241         data3.push_back(value1);
242         data3.push_back(value2);
243         data3.push_back(*curr);
244         NEXT_ELEM;
245         index3.push_back(idx3);
246 #undef NEXT_ELEM
247 #undef EAT_ZEROS
248       }
249     }
250     col_offset += block_size;
251     index3_offset.push_back(index3.size());
252     index_offset.push_back(index.size());
253   }
254   DCHECK_EQ(index3_offset.size(), num_blocks);
255   DCHECK_EQ(index_offset.size(), num_blocks);
256   DCHECK_EQ(3 * index3.size(), data3.size());
257   DCHECK_EQ(index.size(), data.size());
258 }
259 
260 template <typename T>
261 void SparseSlice<T>::Clear() {
262   index3_offset.clear();
263   index3.clear();
264   data3.clear();
265   index_offset.clear();
266   index.clear();
267   data.clear();
268 }
269 
270 using Packet = Eigen::internal::packet_traits<float>::type;
271 const int kNumOperands = (sizeof(Packet) / sizeof(float));
272 #define LOAD(x) Eigen::internal::pload<Packet>(x);
273 #define EXPAND_BFLOAT_L(x, y) \
274   const auto y = Eigen::internal::pexpand_bf16_l<Packet>(x);
275 #define EXPAND_BFLOAT_U(x, y) \
276   const auto y = Eigen::internal::pexpand_bf16_u<Packet>(x);
277 #define STORE(x, y) Eigen::internal::pstore<float>(x, y);
278 #define FMA(a, b, c, d) d = Eigen::internal::pmadd<Packet>(a, b, c);
279 
280 ALWAYS_INLINE float ConvertBfloat16ToFloat(const bfloat16* src) {
281   float out = 0;
282   auto tmp = reinterpret_cast<bfloat16*>(&out);
283 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
284   tmp[0] = *src;
285 #else
286   tmp[1] = *src;
287 #endif
288   return out;
289 }
290 
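// Editor's note: an equivalent bit-level sketch of the conversion above
// (illustration only; assumes bfloat16 is the upper 16 bits of an IEEE float,
// which is what the byte-order dispatch above relies on; needs <cstdint> and
// <cstring>):
//
//   inline float Bfloat16BitsToFloat(uint16_t bits) {
//     uint32_t widened = static_cast<uint32_t>(bits) << 16;  // zero-pad mantissa
//     float out;
//     memcpy(&out, &widened, sizeof(out));
//     return out;
//   }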
291 ALWAYS_INLINE Packet ConvertFourBfloat16ToFloat(const bfloat16* src) {
292   return Eigen::internal::pload4bf16<Packet>(
293       reinterpret_cast<const float*>(src));
294 }
295 
296 ALWAYS_INLINE Packet ConvertTwoBfloat16ToFloat(const bfloat16* src) {
297   return Eigen::internal::pload2bf16<Packet>(
298       reinterpret_cast<const float*>(src));
299 }
300 
301 ALWAYS_INLINE void ScalarMulAdd(const float a, const float** inp, float** out) {
302   **out += a * **inp;
303   ++*inp;
304   ++*out;
305 }
306 
307 ALWAYS_INLINE void ScalarMulAdd(const float a, const bfloat16** inp,
308                                 float** out) {
309   float inp_f = ConvertBfloat16ToFloat(*inp);
310   **out += a * inp_f;
311   ++*inp;
312   ++*out;
313 }
314 ALWAYS_INLINE void ScalarMulAdd3Way(const float a1, const float a2,
315                                     const float a3, const bfloat16** inp1,
316                                     const bfloat16** inp2,
317                                     const bfloat16** inp3, float** out) {
318   float inp1_f = ConvertBfloat16ToFloat(*inp1);
319   float inp2_f = ConvertBfloat16ToFloat(*inp2);
320   float inp3_f = ConvertBfloat16ToFloat(*inp3);
321   **out += a1 * inp1_f + a2 * inp2_f + a3 * inp3_f;
322   ++*out;
323   ++*inp1;
324   ++*inp2;
325   ++*inp3;
326 }
327 
328 ALWAYS_INLINE void ScalarMulAdd3Way(const float a1, const float a2,
329                                     const float a3, const float** inp1,
330                                     const float** inp2, const float** inp3,
331                                     float** out) {
332   **out += a1 * **inp1 + a2 * **inp2 + a3 * **inp3;
333   ++*out;
334   ++*inp1;
335   ++*inp2;
336   ++*inp3;
337 }
338 
339 ALWAYS_INLINE void LoadSingleScalar(const bfloat16** data, Packet* l) {
340   auto tmp = ConvertBfloat16ToFloat(*data);
341   *l = Eigen::internal::pset1<Packet>(tmp);
342   ++*data;
343 }
344 
345 ALWAYS_INLINE void LoadTwoScalars(const bfloat16** data, Packet* l1,
346                                   Packet* l2) {
347   if (kNumOperands >= 2) {
348     auto tmp = ConvertTwoBfloat16ToFloat(*data);
349     *l1 = Eigen::internal::pbroadcast_first<Packet>(tmp);
350     *l2 = Eigen::internal::pbroadcast_second<Packet>(tmp);
351     *data += 2;
352   } else {
353     LoadSingleScalar(data, l1);
354     LoadSingleScalar(data, l2);
355   }
356 }
357 
358 ALWAYS_INLINE void LoadFourScalars(const bfloat16** data, Packet* l1,
359                                    Packet* l2, Packet* l3, Packet* l4) {
360   if (kNumOperands >= 4) {
361     auto tmp = ConvertFourBfloat16ToFloat(*data);
362     *l1 = Eigen::internal::pbroadcast_first<Packet>(tmp);
363     *l2 = Eigen::internal::pbroadcast_second<Packet>(tmp);
364     *l3 = Eigen::internal::pbroadcast_third<Packet>(tmp);
365     *l4 = Eigen::internal::pbroadcast_fourth<Packet>(tmp);
366     *data += 4;
367   } else {
368     LoadTwoScalars(data, l1, l2);
369     LoadTwoScalars(data, l3, l4);
370   }
371 }
372 
373 ALWAYS_INLINE void LoadSingleScalar(const float** data, Packet* l) {
374   *l = Eigen::internal::pload1<Packet>(*data);
375   ++(*data);
376 }
377 
378 ALWAYS_INLINE void LoadTwoScalars(const float** data, Packet* l1, Packet* l2) {
379   LoadSingleScalar(data, l1);
380   LoadSingleScalar(data, l2);
381 }
382 
383 ALWAYS_INLINE void LoadFourScalars(const float** data, Packet* l1, Packet* l2,
384                                    Packet* l3, Packet* l4) {
385   LoadTwoScalars(data, l1, l2);
386   LoadTwoScalars(data, l3, l4);
387 }
388 
389 template <typename T>
390 ALWAYS_INLINE void LoadThreeScalars(const T** data, Packet* l1, Packet* l2,
391                                     Packet* l3) {
392   LoadTwoScalars(data, l1, l2);
393   LoadSingleScalar(data, l3);
394 }
395 
396 template <typename T>
397 ALWAYS_INLINE void LoadSixScalars(const T** data, Packet* l1, Packet* l2,
398                                   Packet* l3, Packet* l4, Packet* l5,
399                                   Packet* l6) {
400   LoadFourScalars(data, l1, l2, l3, l4);
401   LoadTwoScalars(data, l5, l6);
402 }
403 
404 // Vectorized version of ScalarMulAdd.
405 ALWAYS_INLINE void MulAdd(const Packet a, const bfloat16** binp, float** out) {
406   auto inp = reinterpret_cast<const float*>(*binp);
407   const auto b = LOAD(inp);
408   EXPAND_BFLOAT_L(b, b_0);
409   EXPAND_BFLOAT_U(b, b_1);
410   *binp += 2 * kNumOperands;
411   auto c1 = LOAD(*out);
412   auto c2 = LOAD(*out + kNumOperands);
413   FMA(a, b_0, c1, c1);
414   FMA(a, b_1, c2, c2);
415   STORE(*out, c1);
416   STORE(*out + kNumOperands, c2);
417   *out += 2 * kNumOperands;
418 }
419 
420 // Vectorized version of ScalarMulAdd3Way.
421 ALWAYS_INLINE void MulAdd3Way(const Packet a1, const Packet a2, const Packet a3,
422                               const bfloat16** binp1, const bfloat16** binp2,
423                               const bfloat16** binp3, float** out) {
424   auto inp1 = reinterpret_cast<const float*>(*binp1);
425   auto inp2 = reinterpret_cast<const float*>(*binp2);
426   auto inp3 = reinterpret_cast<const float*>(*binp3);
427   auto c1 = LOAD(*out);
428   auto c2 = LOAD(*out + kNumOperands);
429   const auto b1 = LOAD(inp1);
430   EXPAND_BFLOAT_L(b1, b1_0);
431   EXPAND_BFLOAT_U(b1, b1_1);
432   *binp1 += 2 * kNumOperands;
433   const auto b2 = LOAD(inp2);
434   EXPAND_BFLOAT_L(b2, b2_0);
435   EXPAND_BFLOAT_U(b2, b2_1);
436   *binp2 += 2 * kNumOperands;
437   const auto b3 = LOAD(inp3);
438   EXPAND_BFLOAT_L(b3, b3_0);
439   EXPAND_BFLOAT_U(b3, b3_1);
440   *binp3 += 2 * kNumOperands;
441   FMA(a1, b1_0, c1, c1);
442   FMA(a1, b1_1, c2, c2);
443   FMA(a2, b2_0, c1, c1);
444   FMA(a2, b2_1, c2, c2);
445   FMA(a3, b3_0, c1, c1);
446   FMA(a3, b3_1, c2, c2);
447   STORE(*out, c1);
448   STORE(*out + kNumOperands, c2);
449   *out += 2 * kNumOperands;
450 }
451 
452 // Unroll MulAdd3Way for two iterations
453 ALWAYS_INLINE void TwoMulAdd3Way(const Packet a1, const Packet a2,
454                                  const Packet a3, const bfloat16** binp1,
455                                  const bfloat16** binp2, const bfloat16** binp3,
456                                  float** out) {
457   auto inp1 = reinterpret_cast<const float*>(*binp1);
458   auto inp2 = reinterpret_cast<const float*>(*binp2);
459   auto inp3 = reinterpret_cast<const float*>(*binp3);
460   auto c1 = LOAD(*out);
461   auto c2 = LOAD(*out + kNumOperands);
462   const auto b1 = LOAD(inp1);
463   const auto b2 = LOAD(inp2);
464   const auto b3 = LOAD(inp3);
465 
466   EXPAND_BFLOAT_L(b1, b1_0);
467   EXPAND_BFLOAT_U(b1, b1_1);
468   EXPAND_BFLOAT_L(b2, b2_0);
469   EXPAND_BFLOAT_U(b2, b2_1);
470   EXPAND_BFLOAT_L(b3, b3_0);
471   EXPAND_BFLOAT_U(b3, b3_1);
472   auto c3 = LOAD(*out + 2 * kNumOperands);
473   auto c4 = LOAD(*out + 3 * kNumOperands);
474   const auto b4 = LOAD(inp1 + kNumOperands);
475   const auto b5 = LOAD(inp2 + kNumOperands);
476   const auto b6 = LOAD(inp3 + kNumOperands);
477 
478   EXPAND_BFLOAT_L(b4, b4_0);
479   EXPAND_BFLOAT_U(b4, b4_1);
480   EXPAND_BFLOAT_L(b5, b5_0);
481   EXPAND_BFLOAT_U(b5, b5_1);
482   EXPAND_BFLOAT_L(b6, b6_0);
483   EXPAND_BFLOAT_U(b6, b6_1);
484 
485   FMA(a1, b1_0, c1, c1);
486   FMA(a1, b1_1, c2, c2);
487   FMA(a1, b4_0, c3, c3);
488   FMA(a1, b4_1, c4, c4);
489   FMA(a2, b2_0, c1, c1);
490   FMA(a2, b2_1, c2, c2);
491   FMA(a2, b5_0, c3, c3);
492   FMA(a2, b5_1, c4, c4);
493   FMA(a3, b3_0, c1, c1);
494   FMA(a3, b3_1, c2, c2);
495   FMA(a3, b6_0, c3, c3);
496   FMA(a3, b6_1, c4, c4);
497   STORE(*out, c1);
498   STORE(*out + kNumOperands, c2);
499   STORE(*out + 2 * kNumOperands, c3);
500   STORE(*out + 3 * kNumOperands, c4);
501   *out += 4 * kNumOperands;
502   *binp1 += 4 * kNumOperands;
503   *binp2 += 4 * kNumOperands;
504   *binp3 += 4 * kNumOperands;
505 }
506 
507 // Apply MulAdd3Way on 128 operands.
508 ALWAYS_INLINE void MulAdd3Way128(const Packet a1, const Packet a2,
509                                  const Packet a3, const bfloat16** inp1,
510                                  const bfloat16** inp2, const bfloat16** inp3,
511                                  float** out) {
512   for (int k = 0; k < 128 / (8 * kNumOperands); ++k) {
513     TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
514     TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
515   }
516 }
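// Editor's note (arithmetic check, assuming kNumOperands == 8, e.g. AVX):
// each TwoMulAdd3Way above consumes 4 * kNumOperands = 32 output floats, the
// two calls per iteration cover 64, and the loop runs 128 / 64 = 2 times,
// which matches the 128 output columns promised by MulAdd3Way128.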
517 
518 // Vectorized version of ScalarMulAdd
519 ALWAYS_INLINE void MulAdd(const Packet a, const float** inp, float** out) {
520   const auto b = LOAD(*inp);
521   *inp += kNumOperands;
522   auto c = LOAD(*out);
523   FMA(a, b, c, c);
524   STORE(*out, c);
525   *out += kNumOperands;
526 }
527 
528 // Vectorized version of ScalarMulAdd3Way
529 ALWAYS_INLINE void MulAdd3Way(const Packet a1, const Packet a2, const Packet a3,
530                               const float** inp1, const float** inp2,
531                               const float** inp3, float** out) {
532   auto c = LOAD(*out);
533   const auto b1 = LOAD(*inp1);
534   *inp1 += kNumOperands;
535   const auto b2 = LOAD(*inp2);
536   *inp2 += kNumOperands;
537   const auto b3 = LOAD(*inp3);
538   *inp3 += kNumOperands;
539   FMA(a1, b1, c, c);
540   FMA(a2, b2, c, c);
541   FMA(a3, b3, c, c);
542   STORE(*out, c);
543   *out += kNumOperands;
544 }
545 
546 // Unroll MulAdd3Way for two iterations
547 ALWAYS_INLINE void TwoMulAdd3Way(const Packet a1, const Packet a2,
548                                  const Packet a3, const float** inp1,
549                                  const float** inp2, const float** inp3,
550                                  float** out) {
551   auto c1 = LOAD(*out);
552   const auto b1 = LOAD(*inp1);
553   const auto b2 = LOAD(*inp2);
554   const auto b3 = LOAD(*inp3);
555 
556   auto c2 = LOAD(*out + kNumOperands);
557   const auto b4 = LOAD(*inp1 + kNumOperands);
558   const auto b5 = LOAD(*inp2 + kNumOperands);
559   const auto b6 = LOAD(*inp3 + kNumOperands);
560 
561   FMA(a1, b1, c1, c1);
562   FMA(a1, b4, c2, c2);
563   FMA(a2, b2, c1, c1);
564   FMA(a2, b5, c2, c2);
565   FMA(a3, b3, c1, c1);
566   FMA(a3, b6, c2, c2);
567   STORE(*out, c1);
568   STORE(*out + kNumOperands, c2);
569   *out += 2 * kNumOperands;
570   *inp1 += 2 * kNumOperands;
571   *inp2 += 2 * kNumOperands;
572   *inp3 += 2 * kNumOperands;
573 }
574 
575 // Unroll MulAdd3Way for four iterations
576 ALWAYS_INLINE void FourMulAdd3Way(const Packet a1, const Packet a2,
577                                   const Packet a3, const float** inp1,
578                                   const float** inp2, const float** inp3,
579                                   float** out) {
580   TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
581   TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
582 }
583 
584 // Apply MulAdd3Way on 128 operands.
585 ALWAYS_INLINE void MulAdd3Way128(const Packet a1, const Packet a2,
586                                  const Packet a3, const float** inp1,
587                                  const float** inp2, const float** inp3,
588                                  float** out) {
589   if (kNumOperands == 8) {
590     FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
591     FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
592     FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
593     FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
594   } else {
595     DCHECK_LE(4 * kNumOperands, 128);
596     for (int i = 0; i < 128 / (4 * kNumOperands); ++i) {
597       MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
598       MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
599       MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
600       MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
601     }
602   }
603 }
604 // Computes product of "left_slices" with "num_cols" columns of "right", and
605 // stores the output in *"output".
606 // Note that left_slices is a list of SparseSlices, which are conceptually
607 // assumed to be concatenated along the column dimension. Also each SparseSlice
608 // is encoded as a list of blocks with up to N columns. See SparseSlice for more
609 // details.
610 template <typename TL, typename TR, int Cols>
611 inline void GEPP(
612     const std::vector<SparseSlice<TL>*>& left_slices,
613     const Eigen::TensorMap<Eigen::Tensor<const TR, 2, Eigen::RowMajor>,
614                            Eigen::Aligned>& right,
615     const int num_cols, Matrix* output) {
616   const int cols = (Cols == -1) ? num_cols : Cols;
617   DCHECK_EQ(num_cols, cols);
618   const int right_num_cols = right.dimension(1);
619   const int output_num_cols = output->dimension(1);
620   static const int kNumOperandsR = kNumOperands * sizeof(float) / sizeof(TR);
621   const int cols_mod = cols % kNumOperandsR;
622   int k_offset = 0;
623   // Pre-compute pointers for output matrix.
624   float* out_ptrs[M];
625   float* const out_start = &(*output)(0, 0);
626   for (int j = 0; j < M; ++j) {
627     out_ptrs[j] = out_start + output_num_cols * j;
628   }
629   for (const auto* left_slice : left_slices) {
630     const auto& left = *left_slice;
631     const auto* data3 = (!left.data3.empty()) ? &left.data3[0] : nullptr;
632     const auto* data = (!left.data.empty()) ? &left.data[0] : nullptr;
633     const int num_blocks = left.index3_offset.size();
634     int begin3 = 0;
635     int begin = 0;
636     for (int i = 0; i < num_blocks; ++i) {
637       // Pre-compute pointers for right matrix
638       const TR* right_ptrs[K];
639       const auto* const right_start = &right(k_offset, 0);
640       DCHECK_LT(k_offset, right.dimension(0));
641       for (int j = 0; j < K; ++j) {
642         right_ptrs[j] = right_start + right_num_cols * j;
643       }
644 
645       const int end3 = left.index3_offset[i];
646       int j = begin3;
647       // Loop unrolled for 2 iterations.
648       for (; j + 1 < end3; j += 2) {
649         Packet l1, l2, l3, nl1, nl2, nl3;
650         LoadSixScalars(&data3, &l1, &l2, &l3, &nl1, &nl2, &nl3);
651         const auto& index = left.index3[j];
652         const auto& nindex = left.index3[j + 1];
653         float* out = out_ptrs[index.m];
654         float* nout = out_ptrs[nindex.m];
655         const auto* r1 = right_ptrs[index.k1];
656         const auto* r2 = right_ptrs[index.k2];
657         const auto* r3 = right_ptrs[index.k3];
658 
659         const auto* nr1 = right_ptrs[nindex.k1];
660         const auto* nr2 = right_ptrs[nindex.k2];
661         const auto* nr3 = right_ptrs[nindex.k3];
662         if (cols == 128) {
663           MulAdd3Way128(l1, l2, l3, &r1, &r2, &r3, &out);
664           MulAdd3Way128(nl1, nl2, nl3, &nr1, &nr2, &nr3, &nout);
665         } else {
666           for (int n = 0; n < cols / kNumOperandsR; ++n) {
667             MulAdd3Way(l1, l2, l3, &r1, &r2, &r3, &out);
668             MulAdd3Way(nl1, nl2, nl3, &nr1, &nr2, &nr3, &nout);
669           }
670 
671           const float sl1 = Eigen::internal::pfirst<Packet>(l1);
672           const float sl2 = Eigen::internal::pfirst<Packet>(l2);
673           const float sl3 = Eigen::internal::pfirst<Packet>(l3);
674           const float nsl1 = Eigen::internal::pfirst<Packet>(nl1);
675           const float nsl2 = Eigen::internal::pfirst<Packet>(nl2);
676           const float nsl3 = Eigen::internal::pfirst<Packet>(nl3);
677           for (int k = 0; k < cols_mod; ++k) {
678             ScalarMulAdd3Way(sl1, sl2, sl3, &r1, &r2, &r3, &out);
679             ScalarMulAdd3Way(nsl1, nsl2, nsl3, &nr1, &nr2, &nr3, &nout);
680           }
681         }
682       }
683       if (j < end3) {
684         Packet l1, l2, l3;
685         LoadThreeScalars(&data3, &l1, &l2, &l3);
686 
687         const auto& index = left.index3[j];
688         float* out = out_ptrs[index.m];
689         const auto* r1 = right_ptrs[index.k1];
690         const auto* r2 = right_ptrs[index.k2];
691         const auto* r3 = right_ptrs[index.k3];
692         if (cols == 128) {
693           MulAdd3Way128(l1, l2, l3, &r1, &r2, &r3, &out);
694         } else {
695           for (int n = 0; n < cols / kNumOperandsR; ++n) {
696             MulAdd3Way(l1, l2, l3, &r1, &r2, &r3, &out);
697           }
698           const float sl1 = Eigen::internal::pfirst<Packet>(l1);
699           const float sl2 = Eigen::internal::pfirst<Packet>(l2);
700           const float sl3 = Eigen::internal::pfirst<Packet>(l3);
701           for (int k = 0; k < cols_mod; ++k) {
702             ScalarMulAdd3Way(sl1, sl2, sl3, &r1, &r2, &r3, &out);
703           }
704         }
705       }
706       begin3 = end3;
707       int end = left.index_offset[i];
708       // Loop unrolled for 4 iterations.
709       j = begin;
710       for (; j + 3 < end; j += 4) {
711         Packet l, nl, n2l, n3l;
712         LoadFourScalars(&data, &l, &nl, &n2l, &n3l);
713 
714         const auto& index = left.index[j];
715         const auto& nindex = left.index[j + 1];
716         const auto& n2index = left.index[j + 2];
717         const auto& n3index = left.index[j + 3];
718         const auto* r = right_ptrs[index.k];
719         const auto* nr = right_ptrs[nindex.k];
720         const auto* n2r = right_ptrs[n2index.k];
721         const auto* n3r = right_ptrs[n3index.k];
722         float* out = out_ptrs[index.m];
723         float* nout = out_ptrs[nindex.m];
724         float* n2out = out_ptrs[n2index.m];
725         float* n3out = out_ptrs[n3index.m];
726 
727         for (int n = 0; n < cols / kNumOperandsR; ++n) {
728           MulAdd(l, &r, &out);
729           MulAdd(nl, &nr, &nout);
730           MulAdd(n2l, &n2r, &n2out);
731           MulAdd(n3l, &n3r, &n3out);
732         }
733 
734         const float sl1 = Eigen::internal::pfirst<Packet>(l);
735         const float sl2 = Eigen::internal::pfirst<Packet>(nl);
736         const float sl3 = Eigen::internal::pfirst<Packet>(n2l);
737         const float sl4 = Eigen::internal::pfirst<Packet>(n3l);
738         for (int k = 0; k < cols_mod; ++k) {
739           ScalarMulAdd(sl1, &r, &out);
740           ScalarMulAdd(sl2, &nr, &nout);
741           ScalarMulAdd(sl3, &n2r, &n2out);
742           ScalarMulAdd(sl4, &n3r, &n3out);
743         }
744       }
745       while (j < end) {
746         Packet l;
747         LoadSingleScalar(&data, &l);
748         const auto& index = left.index[j];
749         const auto* r = right_ptrs[index.k];
750         float* out = out_ptrs[index.m];
751         for (int n = 0; n < cols / kNumOperandsR; ++n) {
752           MulAdd(l, &r, &out);
753         }
754         const float sl = Eigen::internal::pfirst<Packet>(l);
755         for (int k = 0; k < cols_mod; ++k) {
756           ScalarMulAdd(sl, &r, &out);
757         }
758         j++;
759       }
760       k_offset += left.block_size;
761       begin = end;
762     }
763   }
764 }
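// Editor's note: a scalar reference for what GEPP computes per Index3 entry
// (illustrative sketch, not part of the original file). For a triple
// (m, k1, k2, k3) with values (v1, v2, v3) taken from data3, the kernel does
//
//   for (int c = 0; c < num_cols; ++c) {
//     (*output)(m, c) += v1 * right(k_offset + k1, c) +
//                        v2 * right(k_offset + k2, c) +
//                        v3 * right(k_offset + k3, c);
//   }
//
// the vectorized paths simply perform this update kNumOperandsR columns at a
// time, with the cols_mod tail handled by ScalarMulAdd3Way.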
765 
766 #undef LOAD
767 #undef EXPAND_BFLOAT_L
768 #undef EXPAND_BFLOAT_U
769 #undef STORE
770 #undef FMA
771 
772 }  // namespace
773 
774 template <typename TL, typename TR>
775 class SparseMatMul {
776   using MatrixL = BasicMatrix<TL>;
777   using MatrixR = BasicMatrix<TR>;
778   using ConstMatrixMapL = BasicMatrixMap<const TL>;
779   using ConstMatrixMapR = BasicMatrixMap<const TR>;
780   using MatrixMapR = BasicMatrixMap<TR>;
781 
782  public:
783   // Not used; added to match interface of LibxsmmSparseMatMul
784   struct TensorInfoCache {};
785 
786   // Perform matrix multiplication of "left" and "right", and store the result
787   // in *"output".
788  public:
789   static inline void Compute(TensorInfoCache* cache,
790                              const ConstMatrixMapL& left,
791                              const ConstMatrixMapR& right, bool transpose_left,
792                              const DeviceBase::CpuWorkerThreads* thread_pool,
793                              bool transpose_output, MatrixMap* output);
794 
795  private:
796   // Computes multiplication of left and num_cols columns of right, and stores
797   // the output block in *"output" at offsets "output_row_offset" and
798   // "output_col_offset". If assign is true, assigns the value to that block,
799   // else adds the values to the existing values.
800   static inline void ComputeOutputBlock(
801       const std::vector<SparseSlice<TL>*>& left, const ConstMatrixMapR& right,
802       int num_cols, int output_row_offset, int output_col_offset, bool assign,
803       bool transpose_output, MatrixMap* output);
804 
805   // Encodes "mat" using a sparse representation and stores that in
806   // "mat_slices". "mat" is broken into a grid with sizes "slice_num_rows" and
807   // "slice_num_cols", each grid element is converted into a SparseSlice and
808   // stored in mat_slices. "slice_block_size" is used to perform further column
809   // blocking of each slice.
810   static inline std::unique_ptr<BlockingCounter> CreateSparseSlices(
811       const ConstMatrixMapL& mat, bool transpose, int slice_num_rows,
812       int slice_block_size, int slice_num_cols,
813       std::vector<std::vector<SparseSlice<TL>*>>* mat_slices,
814       const DeviceBase::CpuWorkerThreads* thread_pool);
815 
816   // This function chops "mat" along column dimension into pieces with at most N
817   // columns, and concatenates the pieces one after the other in "buffer". It
818   // returns the list of the pieces in "slices". It returns a BlockingCounter
819   // which should be used to wait for the shuffle operations to complete.
820   static inline std::unique_ptr<BlockingCounter> CreateDenseSlices(
821       const ConstMatrixMapR& mat, int row_start, int num_rows, int col_start,
822       int num_cols, const DeviceBase::CpuWorkerThreads* thread_pool,
823       MatrixR* buffer, std::vector<ConstMatrixMapR*>* slices);
824 
825   // Helper function for CreateDenseSlices to move the data around. It returns a
826   // BlockingCounter which should be used to wait for the shuffle operations to
827   // complete.
828   static inline BlockingCounter* ShuffleMatrix(
829       const ConstMatrixMapR& mat, int slice_row_start, int slice_num_rows,
830       int slice_col_start, int slice_num_cols, const int N,
831       const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer);
832 
833   // Helper function for CreateDenseSlices to create slices.
834   static inline void SliceMatrix(const MatrixR& mat, const int num_rows,
835                                  const int num_slices,
836                                  std::vector<ConstMatrixMapR*>* slices);
837 
838   // Heuristics to compute various block sizes.
839   // KR, NR: block sizes for "right". We run blocking iterations that operate on
840   // matrices with at most this size.
841   // KL: grid size along the column dimension used while encoding left.
842   // IB, JB: number of left and right slices to multiply together. This is used
843   // for ordering different ComputeBlockOutput operations inside each blocking
844   // iteration so as to potentially reduce the working set size.
845   static inline void ComputeBlockSizes(const ConstMatrixMapL& left,
846                                        const ConstMatrixMapR& right,
847                                        bool transpose_left, int num_threads,
848                                        int* KR, int* NR, int* KL, int* JB,
849                                        int* IB);
850 
851   TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMul);
852 };
853 
854 #ifdef TENSORFLOW_USE_LIBXSMM
855 template <typename TL, typename TR>
856 class LibxsmmSparseMatMul {
857   using MatrixL = BasicMatrix<TL>;
858   using MatrixR = BasicMatrix<TR>;
859   using ConstMatrixMapL = BasicMatrixMap<const TL>;
860   using ConstMatrixMapR = BasicMatrixMap<const TR>;
861   using MatrixMapR = BasicMatrixMap<TR>;
862 
863  public:
864   // This structure contains a set of libxsmm kernels for sizes that have been
865   // encountered previously by this operator so that libxsmm does not need to
866   // reallocate its scratchpad memory each time (which hurts performance
867   // substantially).
868   struct TensorInfoCache {
869     struct TensorInfoCacheEntry {
870       // Parameters for kernel
871       int M;
872       int K;
873       int N;
874       int max_threads;
875       // libxsmm handle and matrix data
876       libxsmm_spmdm_handle handle;
877       libxsmm_CSR_sparseslice* output_csr;
878       // Chain to non-libxsmm implementation's cache in case that ever becomes
879       // useful (it is an empty struct right now)
880       typename SparseMatMul<TL, TR>::TensorInfoCache
881           non_libxsmm_cache;  // Currently not used
882     };
883     // protects entries; invariant: entries is a valid std::multimap
884     tensorflow::mutex lock;
885     // Because there could be multiple matrix multiplies with the same sizes
886     // going on at the same time, we need to allow multiple cache entries for a
887     // given set of parameters. Taking and returning entries is used to make
888     // sure the same cache entry is not used from two threads at a time.
889     std::multimap<std::tuple<int, int, int, int>,
890                   std::unique_ptr<TensorInfoCacheEntry>>
891         entries TF_GUARDED_BY(lock);
892 
893     TensorInfoCache() : lock(), entries() {}
894     // Look up and remove first entry with these parameters, creating one if
895     // there isn't one
896     std::unique_ptr<TensorInfoCacheEntry> take_cache_entry(int M, int K, int N,
897                                                            int max_threads)
898         TF_LOCKS_EXCLUDED(lock) {
899       tensorflow::mutex_lock ml(lock);
900       auto key = std::make_tuple(M, K, N, max_threads);
901       auto it = entries.find(key);
902       if (it != entries.end()) {
903         auto val = std::move(it->second);
904         entries.erase(it);
905         return val;
906       } else {
907         std::unique_ptr<TensorInfoCacheEntry> e{
908             new TensorInfoCacheEntry{M, K, N, max_threads, {}, nullptr}};
909         // setup scoped allocator, which uses cpu_allocator() for this scope
910         const libxsmm_tf_allocator<libxsmm_scratch_allocator> tf_allocator;
911         libxsmm_spmdm_init(M, N, K, max_threads, &e->handle, &e->output_csr);
912         return e;
913       }
914     }
915     // Add a cache entry with certain parameters
916     void return_cache_entry(std::unique_ptr<TensorInfoCacheEntry> e)
917         TF_LOCKS_EXCLUDED(lock) {
918       tensorflow::mutex_lock ml(lock);
919       auto key = std::make_tuple(e->M, e->K, e->N, e->max_threads);
920       entries.insert(std::make_pair(key, std::move(e)));
921     }
922     ~TensorInfoCache() {
923       tensorflow::mutex_lock ml(lock);
924       for (auto& p : entries) {
925         libxsmm_spmdm_destroy(&p.second->handle);
926       }
927       entries.clear();
928     }
929 
930    private:
931     TF_DISALLOW_COPY_AND_ASSIGN(TensorInfoCache);
932   };
933 
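  // Editor's note: intended usage of the cache above (illustrative sketch):
  //
  //   auto e = cache->take_cache_entry(M, K, N, max_threads);
  //   // ... run the libxsmm spmdm kernels via e->handle / e->output_csr ...
  //   cache->return_cache_entry(std::move(e));
  //
  // Concurrent multiplies with identical (M, K, N, max_threads) each take
  // their own entry, so a libxsmm handle is never shared between threads.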
934   // Perform matrix multiplication of "left" and "right", and store the result
935   // in *"output".
936  public:
937   static inline void Compute(TensorInfoCache* cache,
938                              const ConstMatrixMapL& left,
939                              const ConstMatrixMapR& right, bool transpose_left,
940                              const DeviceBase::CpuWorkerThreads* thread_pool,
941                              bool transpose_output, MatrixMap* output);
942 
943  private:
944   TF_DISALLOW_COPY_AND_ASSIGN(LibxsmmSparseMatMul);
945 };
946 #endif
947 
948 template <typename TL, typename TR,
949           template <typename TL2, typename TR2> class DoMatMul>
950 class SparseMatMulOp : public OpKernel {
951   using MatrixR = BasicMatrix<TR>;
952   using ConstMatrixMapR = BasicMatrixMap<const TR>;
953 
954  public:
955   explicit SparseMatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
956     OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
957     OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
958     OP_REQUIRES_OK(ctx, ctx->GetAttr("a_is_sparse", &a_is_sparse_));
959     OP_REQUIRES_OK(ctx, ctx->GetAttr("b_is_sparse", &b_is_sparse_));
960   }
961 
962   void Compute(OpKernelContext* ctx) override {
963     const Tensor& a = ctx->input(0);
964     const Tensor& b = ctx->input(1);
965     OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a.shape()),
966                 errors::InvalidArgument("a is not a matrix"));
967     OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()),
968                 errors::InvalidArgument("b is not a matrix"));
969 
970     const int m = transpose_a_ ? a.dim_size(1) : a.dim_size(0);
971     const int k = transpose_a_ ? a.dim_size(0) : a.dim_size(1);
972     const int n = transpose_b_ ? b.dim_size(0) : b.dim_size(1);
973     const int k2 = transpose_b_ ? b.dim_size(1) : b.dim_size(0);
974 
975     OP_REQUIRES(ctx, k == k2,
976                 errors::InvalidArgument(
977                     "Matrix size incompatible: a: ", a.shape().DebugString(),
978                     ", b: ", b.shape().DebugString()));
979     OP_REQUIRES(ctx, m >= 0 && n >= 0 && k >= 0,
980                 errors::InvalidArgument(
981                     "Matrix dimensions cannot be negative: a: ",
982                     a.shape().DebugString(), ", b: ", b.shape().DebugString()));
983     Tensor* output = nullptr;
984     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({m, n}), &output));
985 
986     // Return early if at least one of the output dimension sizes is 0.
987     if (m == 0 || n == 0) {
988       return;
989     }
990 
991     if (k == 0) {
992       // If the inner dimension k in the matrix multiplication is zero, we fill
993       // the output with zeros.
994       functor::SetZeroFunctor<CPUDevice, float> f;
995       f(ctx->eigen_device<CPUDevice>(), output->flat<float>());
996       return;
997     }
998 
999     auto out = output->matrix<float>();
1000 
1001     std::unique_ptr<Tensor> a_float;
1002     std::unique_ptr<Tensor> b_float;
1003     if (!a_is_sparse_ && !b_is_sparse_) {
1004       auto left = &a;
1005       auto right = &b;
1006       // TODO(agarwal): multi-thread the conversions from bfloat16 to float.
1007       if (std::is_same<TL, bfloat16>::value) {
1008         a_float.reset(new Tensor(DT_FLOAT, a.shape()));
1009         BFloat16ToFloat(a.flat<bfloat16>().data(),
1010                         a_float->flat<float>().data(), a.NumElements());
1011         left = a_float.get();
1012       }
1013       if (std::is_same<TR, bfloat16>::value) {
1014         b_float.reset(new Tensor(DT_FLOAT, b.shape()));
1015         BFloat16ToFloat(b.flat<bfloat16>().data(),
1016                         b_float->flat<float>().data(), b.NumElements());
1017         right = b_float.get();
1018       }
1019       Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
1020       dim_pair[0].first = transpose_a_ ? 0 : 1;
1021       dim_pair[0].second = transpose_b_ ? 1 : 0;
1022 
1023       out.device(ctx->template eigen_device<CPUDevice>()) =
1024           left->matrix<float>().contract(right->matrix<float>(), dim_pair);
1025       return;
1026     }
1027 
1028     auto left = &a;
1029     auto right = &b;
1030     bool transpose_output = false;
1031     bool transpose_a = transpose_a_;
1032     bool transpose_b = transpose_b_;
1033     if (!a_is_sparse_) {
1034       // Swap the order of multiplications using the identity:
1035       // A * B = (B' *  A')'.
1036       std::swap(left, right);
1037       std::swap(transpose_a, transpose_b);
1038       transpose_a = !transpose_a;
1039       transpose_b = !transpose_b;
1040       transpose_output = !transpose_output;
1041     }
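    // Editor's note (illustrative): after the swap above the sparse operand is
    // always "left". E.g. for dense A (m x k) times sparse B (k x n), the op
    // computes B' * A' with transpose_output == true, and the transposed write
    // of the result restores A * B without ever encoding A as sparse.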
1042 
1043     std::unique_ptr<Tensor> right_tr;
1044     if (transpose_b) {
1045       // TODO(agarwal): avoid transposing the matrix here and directly handle
1046       // transpose in CreateDenseSlices.
1047       OP_REQUIRES(ctx, right->dim_size(0) != 0,
1048                   errors::InvalidArgument("b has an entry 0 in its shape."));
1049       OP_REQUIRES(ctx, right->dim_size(1) != 0,
1050                   errors::InvalidArgument("b has an entry 0 in its shape."));
1051       right_tr.reset(
1052           new Tensor(right->dtype(),
1053                      TensorShape({right->dim_size(1), right->dim_size(0)})));
1054 
1055       const auto perm = dsizes_10();
1056       if (transpose_output) {
1057         right_tr->matrix<TL>().device(ctx->template eigen_device<CPUDevice>()) =
1058             right->matrix<TL>().shuffle(perm);
1059       } else {
1060         right_tr->matrix<TR>().device(ctx->template eigen_device<CPUDevice>()) =
1061             right->matrix<TR>().shuffle(perm);
1062       }
1063       right = right_tr.get();
1064     }
1065 
1066     if (transpose_output) {
1067       DoMatMul<TR, TL>::Compute(&this->cache_tr_, left->matrix<TR>(),
1068                                 right->matrix<TL>(), transpose_a,
1069                                 ctx->device()->tensorflow_cpu_worker_threads(),
1070                                 transpose_output, &out);
1071     } else {
1072       DoMatMul<TL, TR>::Compute(&this->cache_nt_, left->matrix<TL>(),
1073                                 right->matrix<TR>(), transpose_a,
1074                                 ctx->device()->tensorflow_cpu_worker_threads(),
1075                                 transpose_output, &out);
1076     }
1077   }
1078 
1079  private:
1080   bool transpose_a_;
1081   bool transpose_b_;
1082   bool a_is_sparse_;
1083   bool b_is_sparse_;
1084 
1085   // Cache for non-transposed-output multiply
1086   typename DoMatMul<TL, TR>::TensorInfoCache cache_nt_;
1087   // Cache for transposed-output multiply
1088   typename DoMatMul<TR, TL>::TensorInfoCache cache_tr_;
1089 
1090   TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMulOp);
1091 };
1092 
1093 template <typename TL, typename TR>
1094 inline void SparseMatMul<TL, TR>::ComputeOutputBlock(
1095     const std::vector<SparseSlice<TL>*>& left,
1096     const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right, int num_cols,
1097     int output_row_offset, int output_col_offset, bool assign,
1098     bool transpose_output, MatrixMap* output) {
1099   const auto perm = dsizes_10();
1100   int num_rows = left[0]->num_rows;
1101   const int rhs_num_cols = right.dimension(1);
1102   DCHECK_LE(num_cols, rhs_num_cols);
1103   Matrix out(num_rows, rhs_num_cols);
1104   out.setZero();
1105   if (num_cols == N) {
1106     GEPP<TL, TR, N>(left, right, num_cols, &out);
1107   } else {
1108     GEPP<TL, TR, -1>(left, right, num_cols, &out);
1109   }
1110   if (!assign) {
1111     const DSizes begin(output_row_offset, output_col_offset);
1112     const DSizes sizes(num_rows, num_cols);
1113     if (transpose_output) {
1114       if (num_cols == rhs_num_cols) {
1115         output->shuffle(perm).slice(begin, sizes) += out;
1116       } else {
1117         const auto zero = dsizes_00();
1118         output->shuffle(perm).slice(begin, sizes) += out.slice(zero, sizes);
1119       }
1120     } else {
1121       if (num_cols == rhs_num_cols) {
1122         output->slice(begin, sizes) += out;
1123       } else {
1124         const auto zero = dsizes_00();
1125         output->slice(begin, sizes) += out.slice(zero, sizes);
1126       }
1127     }
1128   } else {
1129     std::unique_ptr<Matrix> out_tr;
1130     if (transpose_output) {
1131       out_tr.reset(new Matrix(rhs_num_cols, num_rows));
1132       *out_tr = out.shuffle(perm);
1133       std::swap(output_row_offset, output_col_offset);
1134       std::swap(num_rows, num_cols);
1135     }
1136     const Matrix& final_out = transpose_output ? *out_tr : out;
1137     for (int i = 0; i < num_rows; ++i) {
1138       memcpy(&(*output)(output_row_offset + i, output_col_offset),
1139              &final_out(i, 0), num_cols * sizeof(float));
1140     }
1141   }
1142 }
1143 
1144 template <typename TL, typename TR>
1145 inline std::unique_ptr<BlockingCounter>
1146 SparseMatMul<TL, TR>::CreateSparseSlices(
1147     const typename SparseMatMul<TL, TR>::ConstMatrixMapL& mat, bool transpose,
1148     int slice_num_rows, int slice_block_size, int slice_num_cols,
1149     std::vector<std::vector<SparseSlice<TL>*>>* mat_slices,
1150     const DeviceBase::CpuWorkerThreads* thread_pool) {
1151   const int mat_num_rows = transpose ? mat.dimension(1) : mat.dimension(0);
1152   const int mat_num_cols = transpose ? mat.dimension(0) : mat.dimension(1);
1153   const int num_slices_dim0 =
1154       std::max(1, (mat_num_rows + slice_num_rows - 1) / slice_num_rows);
1155   const int num_slices_dim1 =
1156       std::max(1, (mat_num_cols + slice_num_cols - 1) / slice_num_cols);
1157   mat_slices->resize(num_slices_dim0);
1158   BlockingCounter* counter =
1159       new BlockingCounter(num_slices_dim0 * num_slices_dim1);
1160   auto work = [counter, transpose](SparseSlice<TL>* sparse_slice,
1161                                    SparseMatMul<TL, TR>::ConstMatrixMapL* slice,
1162                                    int col_offset) {
1163     if (transpose) {
1164       sparse_slice->template Initialize<true>(*slice, col_offset);
1165     } else {
1166       sparse_slice->template Initialize<false>(*slice, col_offset);
1167     }
1168     delete slice;
1169     counter->DecrementCount();
1170   };
1171   for (int i = 0; i < num_slices_dim0; ++i) {
1172     (*mat_slices)[i].resize(num_slices_dim1);
1173     int num_rows =
1174         std::min<int>(slice_num_rows, mat_num_rows - i * slice_num_rows);
1175     for (int j = 0; j < num_slices_dim1; ++j) {
1176       int num_cols =
1177           std::min<int>(slice_num_cols, mat_num_cols - j * slice_num_cols);
1178       SparseMatMul<TL, TR>::ConstMatrixMapL* slice = nullptr;
1179       if (transpose) {
1180         slice = new SparseMatMul<TL, TR>::ConstMatrixMapL(
1181             &mat(0, i * slice_num_rows), mat.dimensions());
1182       } else {
1183         DSizes d(num_rows, mat_num_cols);
1184         slice = new SparseMatMul<TL, TR>::ConstMatrixMapL(
1185             &mat(i * slice_num_rows, 0), d);
1186       }
1187       auto* sparse_slice =
1188           new SparseSlice<TL>(num_rows, num_cols, slice_block_size);
1189       (*mat_slices)[i][j] = sparse_slice;
1190       thread_pool->workers->Schedule(
1191           [=]() { work(sparse_slice, slice, slice_num_cols * j); });
1192     }
1193   }
1194   return std::unique_ptr<BlockingCounter>(counter);
1195 }
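// Editor's note (hypothetical sizes, for illustration): for a 1000 x 512
// "left" matrix with slice_num_rows = 64 and slice_num_cols = 256, the grid
// above has ceil(1000/64) = 16 by ceil(512/256) = 2 slices, and the last row
// of slices covers only 1000 - 15 * 64 = 40 rows.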
1196 #define LOAD(x) Eigen::internal::ploadu<Packet>((x));
1197 #define INTERLEAVE(x) Eigen::internal::pinterleave4x64<Packet>(x);
1198 #define STORE(x, y) Eigen::internal::pstoreu<float>(x, y);
1199 
1200 template <int NUM_ELEM = -1>
1201 ALWAYS_INLINE void CopyAndMayBeInterleaveBfloat16(void* bdst, const void* bsrc,
1202                                                   int num_elements) {
1203   DCHECK_GE(kNumOperands, 8);
1204   static const int kStep = kNumOperands * sizeof(float) / sizeof(bfloat16);
1205   const int num = (NUM_ELEM == -1) ? num_elements : NUM_ELEM;
1206   DCHECK_EQ(num, num_elements);
1207   const float* src = reinterpret_cast<const float*>(bsrc);
1208   float* dst = reinterpret_cast<float*>(bdst);
1209   for (int index = 0; index + kStep <= num; index += kStep) {
1210     auto in = LOAD(src);
1211     auto tmp = INTERLEAVE(in);
1212     STORE(dst, tmp);
1213     src += kNumOperands;
1214     dst += kNumOperands;
1215   }
1216   if (num % kStep != 0) {
1217     memcpy(dst, src, (num % kStep) * sizeof(bfloat16));
1218   }
1219 }
1220 
1221 template <typename T>
1222 ALWAYS_INLINE void CopyAndMayBeInterleave(void* dst, const void* src,
1223                                           int num_elements) {
1224   if (std::is_same<T, float>::value || kNumOperands < 8) {
1225     memcpy(dst, src, num_elements * sizeof(T));
1226   } else if (std::is_same<T, bfloat16>::value) {
1227     if (num_elements == N) {
1228       CopyAndMayBeInterleaveBfloat16<N>(dst, src, num_elements);
1229     } else {
1230       CopyAndMayBeInterleaveBfloat16<-1>(dst, src, num_elements);
1231     }
1232   } else {
1233     LOG(FATAL) << "Unsupported type";
1234   }
1235 }
1236 
1237 #undef LOAD
1238 #undef INTERLEAVE
1239 #undef STORE
1240 
1241 template <typename TL, typename TR>
1242 inline BlockingCounter* SparseMatMul<TL, TR>::ShuffleMatrix(
1243     const typename SparseMatMul<TL, TR>::ConstMatrixMapR& mat,
1244     int slice_row_start, int slice_num_rows, int slice_col_start,
1245     int slice_num_cols, const int N,
1246     const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer) {
1247   DCHECK_EQ(N % 2, 0);
1248   DCHECK_LE(kNumOperands * sizeof(float) / sizeof(TR), N);
1249   // Note(nikhilsarda): This heuristic is optimal in benchmarks as of
1250   // Jan 21, 2020.
1251   int num_threads = std::min(thread_pool->num_threads, 8);
1252   BlockingCounter* counter = new BlockingCounter(num_threads);
1253   DCHECK_EQ(N, buffer->dimension(1));
1254   auto shuffle_work = [&mat, slice_row_start, slice_num_rows, slice_col_start,
1255                        slice_num_cols, N, buffer, counter](int s, int e) {
1256     const int row_start = s % slice_num_rows + slice_row_start;
1257     const int col_start = s / slice_num_rows * N + slice_col_start;
1258     auto* out_start = &(*buffer)(s, 0);
1259     const auto* input_start = &mat(row_start, col_start);
1260     const auto* input_end = &mat(slice_row_start + slice_num_rows - 1,
1261                                  slice_col_start + slice_num_cols - 1);
1262     const int mat_num_cols = mat.dimension(1);
1263     const int row_slice_size = slice_num_rows * mat_num_cols;
1264 
1265     const int aligned_end = slice_num_cols / N * slice_num_rows;
1266     const int e1 = std::min(e, aligned_end);
1267     while (s < e1) {
1268       CopyAndMayBeInterleave<TR>(out_start, input_start, N);
1269       out_start += N;
1270       input_start += mat_num_cols;
1271       if (input_start > input_end) {
1272         input_start = input_start - row_slice_size + N;
1273       }
1274       ++s;
1275     }
1276     int s1 = std::max(s, aligned_end);
1277     const int copy_num_cols = slice_num_cols % N;
1278     while (s1 < e) {
1279       CopyAndMayBeInterleave<TR>(out_start, input_start, copy_num_cols);
1280       out_start += N;
1281       input_start += mat_num_cols;
1282       ++s1;
1283     }
1284     if (counter) counter->DecrementCount();
1285   };
1286 
1287   int start = 0;
1288   int end = 0;
1289   int num_out_rows = (slice_num_cols + N - 1) / N * slice_num_rows;
1290   DCHECK_LE(num_out_rows, buffer->dimension(0));
1291   for (int i = std::max(1, num_threads); i > 0; --i) {
1292     end = start + num_out_rows / i;
1293     thread_pool->workers->Schedule([=]() { shuffle_work(start, end); });
1294     num_out_rows -= (end - start);
1295     start = end;
1296   }
1297   return counter;
1298 }
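// Example of the resulting layout (sizes hypothetical): with
// slice_num_rows = 2, slice_num_cols = 256 and N = 128, the slice occupies
// 4 buffer rows. Buffer row s holds the N contiguous elements starting at
// mat(slice_row_start + s % 2, slice_col_start + (s / 2) * N), so rows {0, 1}
// hold the first 128 columns of the two slice rows and rows {2, 3} hold the
// remaining 128 columns, with each run possibly element-interleaved for
// bfloat16 by CopyAndMayBeInterleave.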
1299 
1300 template <typename TL, typename TR>
1301 inline void SparseMatMul<TL, TR>::SliceMatrix(
1302     const MatrixR& mat, const int num_rows, const int num_slices,
1303     std::vector<typename SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) {
1304   slices->resize(num_slices);
1305   DSizes d(num_rows, mat.dimension(1));
1306   DCHECK_LE(num_rows * num_slices, mat.dimension(0));
1307   for (int i = 0; i < num_slices; ++i) {
1308     (*slices)[i] = new ConstMatrixMapR(&mat(i * num_rows, 0), d);
1309   }
1310 }
1311 
1312 template <typename TL, typename TR>
1313 inline std::unique_ptr<BlockingCounter> SparseMatMul<TL, TR>::CreateDenseSlices(
1314     const typename SparseMatMul<TL, TR>::ConstMatrixMapR& mat, int row_start,
1315     int num_rows, int col_start, int num_cols,
1316     const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer,
1317     std::vector<typename SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) {
1318   std::unique_ptr<BlockingCounter> shuffle_counter(ShuffleMatrix(
1319       mat, row_start, num_rows, col_start, num_cols, N, thread_pool, buffer));
1320   const int num_slices = (num_cols + N - 1) / N;
1321   SliceMatrix(*buffer, num_rows, num_slices, slices);
1322   return shuffle_counter;
1323 }
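// Worked example (block sizes hypothetical): for a right-hand block with
// num_rows = 1024, num_cols = 512 and N = 128, ShuffleMatrix packs the block
// into 1024 * 4 buffer rows and SliceMatrix then creates 4 maps of shape
// (1024, 128); slice j views buffer rows [j * 1024, (j + 1) * 1024), which
// correspond to columns [col_start + j * 128, col_start + (j + 1) * 128) of
// the right matrix, stored contiguously.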
1324 
1325 template <typename TL, typename TR>
1326 inline void SparseMatMul<TL, TR>::ComputeBlockSizes(
1327     const typename SparseMatMul<TL, TR>::ConstMatrixMapL& left,
1328     const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right,
1329     bool transpose_left, int num_threads, int* KR, int* NR, int* KL, int* JB,
1330     int* IB) {
1331   // Heuristics for calculating block sizes
1332   // Assume two hyperthreads per core.
1333   const int est_num_cores = std::max(1, (num_threads + 1) / 2);
1334   // Use a block of the rhs with at most 128K floats per core.
1335   const int mem = est_num_cores * 128 * 1024;
1336   *KR = std::min(static_cast<int>(right.dimension(0)), mem / 256);
1337   *NR = right.dimension(1);
1338   if (*KR * *NR > mem) {
1339     // 4096 may be enough to amortize the cost of writes.
1340     *KR = std::min<int>(*KR, 4096);
1341   }
1342   // Use sizes that are multiples of K and 256.
1343   *KR = std::max(1, *KR / K) * K;
1344   *NR = std::max(1, *NR / 256) * 256;
1345   if (*KR * *NR > mem) {
1346     *NR = mem / *KR;
1347   }
1348   *NR = std::max(1, *NR / 256) * 256;
1349 
1350   const int left_dim0 = transpose_left ? left.dimension(1) : left.dimension(0);
1351   const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1);
1352   for (*KL = 1024; *KL > K; *KL /= 2) {
1353     if (*KR % *KL == 0 &&
1354         std::max<int>(1, left_dim0 / 64) * (left_dim1 / *KL) > est_num_cores) {
1355       break;
1356     }
1357   }
1358   DCHECK_EQ(*KL % K, 0);
1359   DCHECK_GE(*KR, *KL);
1360   if (*KR < right.dimension(0)) {
1361     CHECK_EQ(*KR % *KL, 0);
1362   }
1363 
1364   *JB = std::max(1, static_cast<int>(sqrt(num_threads) / 2.0));
1365   *IB = 8 * *JB;
1366   DCHECK_EQ(N * sizeof(float) % 64, size_t{0});
1367 }
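// Worked example (shapes hypothetical): with num_threads = 16 (so
// est_num_cores = 8 and mem = 1048576), a 1024 x 2048 left matrix and a
// 2048 x 2048 right matrix:
//   KR starts at min(2048, mem / 256) = 2048; since KR * NR = 2048 * 2048
//   exceeds mem, NR is reduced to mem / KR = 512 (already a multiple of 256).
//   The KL loop stops immediately at KL = 1024 because 2048 % 1024 == 0 and
//   (1024 / 64) * (2048 / 1024) = 32 > 8. Finally JB = max(1, sqrt(16) / 2)
//   = 2 and IB = 8 * JB = 16.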
1368 
1369 #ifdef TENSORFLOW_USE_LIBXSMM
1370 
1371 template <typename F>
1372 void do_on_all_threads(const DeviceBase::CpuWorkerThreads* thread_pool,
1373                        const F& f) {
1374   int num_threads = thread_pool->num_threads;
1375   if (num_threads == 0) {
1376     LOG(FATAL) << "Have 0 threads in thread pool";
1377   } else if (num_threads == 1) {
1378     f(0);
1379   } else {
1380     BlockingCounter counter(num_threads - 1);
1381     for (int i = 1; i < num_threads; ++i) {
1382       thread_pool->workers->Schedule([&, i]() {
1383         f(i);
1384         counter.DecrementCount();
1385       });
1386     }
1387     f(0);
1388     counter.Wait();
1389   }
1390 }
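// Usage sketch (ComputeShard and partial_results are hypothetical names): the
// helper runs f(i) once for every thread index, executing f(0) on the calling
// thread and blocking until all workers finish, e.g.
//   do_on_all_threads(thread_pool,
//                     [&](int i) { partial_results[i] = ComputeShard(i); });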
1391 
1392 template <typename T>
1393 struct empty_type_wrapper {};
1394 
1395 // Copies of interface to libxsmm_spmdm_createSparseSlice_*_notrans_thread to
1396 // allow overloading
1397 void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
1398     empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA,
1399     const float* A, libxsmm_CSR_sparseslice* libxsmm_output_csr_a, int block_id,
1400     int tid, int nthreads) {
1401   return libxsmm_spmdm_createSparseSlice_fp32_thread(
1402       handle, transA, A, libxsmm_output_csr_a, block_id, tid, nthreads);
1403 }
1404 void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
1405     empty_type_wrapper<bfloat16>, const libxsmm_spmdm_handle* handle,
1406     char transA, const bfloat16* A,
1407     libxsmm_CSR_sparseslice* libxsmm_output_csr_a, int block_id, int tid,
1408     int nthreads) {
1409   return libxsmm_spmdm_createSparseSlice_bfloat16_thread(
1410       handle, transA, reinterpret_cast<const libxsmm_bfloat16*>(A),
1411       libxsmm_output_csr_a, block_id, tid, nthreads);
1412 }
1413 
1414 void wrapper_libxsmm_spmdm_compute_generic_thread(
1415     empty_type_wrapper<bfloat16>, const libxsmm_spmdm_handle* handle,
1416     char transA, char transB, const bfloat16* alpha,
1417     libxsmm_CSR_sparseslice* A_sparse, const bfloat16* B, char transC,
1418     const bfloat16* beta, float* C, int block_id, int tid, int nthreads) {
1419   return libxsmm_spmdm_compute_bfloat16_thread(
1420       handle, transA, transB, reinterpret_cast<const libxsmm_bfloat16*>(alpha),
1421       A_sparse, reinterpret_cast<const libxsmm_bfloat16*>(B), transC,
1422       reinterpret_cast<const libxsmm_bfloat16*>(beta), C, block_id, tid,
1423       nthreads);
1424 }
1425 void wrapper_libxsmm_spmdm_compute_generic_thread(
1426     empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA,
1427     char transB, const float* alpha, libxsmm_CSR_sparseslice* A_sparse,
1428     const float* B, char transC, const float* beta, float* C, int block_id,
1429     int tid, int nthreads) {
1430   return libxsmm_spmdm_compute_fp32_thread(handle, transA, transB, alpha,
1431                                            A_sparse, B, transC, beta, C,
1432                                            block_id, tid, nthreads);
1433 }
1434 
1435 template <typename TL, typename TR>
1436 inline void LibxsmmSparseMatMul<TL, TR>::Compute(
1437     typename LibxsmmSparseMatMul<TL, TR>::TensorInfoCache* cache,
1438     const typename LibxsmmSparseMatMul<TL, TR>::ConstMatrixMapL& left,
1439     const typename LibxsmmSparseMatMul<TL, TR>::ConstMatrixMapR& right,
1440     bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
1441     bool transpose_output, MatrixMap* output) {
1442   const int num_threads = thread_pool->num_threads;
1443   const int left_dim0 = transpose_left ? left.dimension(1) : left.dimension(0);
1444   const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1);
1445   const int right_dim0 = right.dimension(0);
1446   const int right_dim1 = right.dimension(1);
1447   CHECK_EQ(left_dim1, right_dim0);
1448   CHECK_EQ(left_dim0,
1449            (transpose_output ? output->dimension(1) : output->dimension(0)));
1450   CHECK_EQ(right_dim1,
1451            (transpose_output ? output->dimension(0) : output->dimension(1)));
1452 #if 0  // this issue seems to be resolved
1453   if (left_dim0 < 32 || left_dim1 < 32 || right_dim1 < 32) {
1454     // Causes problems in libxsmm
1455     SparseMatMul<TL, TR>::Compute(
1456         nullptr /* Assumes no cached data for fallback */, left, right,
1457         transpose_left, thread_pool, transpose_output, output);
1458     return;
1459   }
1460 #endif
1461   auto left_data = left.data();
1462   auto right_data = right.data();
1463   auto output_data = output->data();
1464   // Initialize libxsmm for this matrix; make sure another thread doesn't use
1465   // this handle
1466   auto entry =
1467       cache->take_cache_entry(left_dim0, right_dim0, right_dim1, num_threads);
1468   // Convert the left matrix to compressed sparse row (CSR) format
1469   ptrdiff_t total_num_creation_blocks =
1470       libxsmm_spmdm_get_num_createSparseSlice_blocks(&entry->handle);
1471   std::atomic<int> cur_create_block_number;
1472   cur_create_block_number.store(0);
1473   do_on_all_threads(thread_pool, [&](int i) {
1474     while (true) {
1475       int work_item = cur_create_block_number.fetch_add(1);
1476       if (work_item >= total_num_creation_blocks) break;
1477       wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
1478           empty_type_wrapper<TL>{}, &entry->handle,
1479           (transpose_left ? 'T' : 'N'), left_data, entry->output_csr, work_item,
1480           i, num_threads);
1481     }
1482   });
1483   // Do matrix-matrix multiplication
1484   ptrdiff_t total_num_mult_blocks =
1485       libxsmm_spmdm_get_num_compute_blocks(&entry->handle);
1486   std::atomic<int> cur_mult_block_number;
1487   cur_mult_block_number.store(0);
1488   do_on_all_threads(thread_pool, [&](int i) {
1489     while (true) {
1490       int work_item = cur_mult_block_number.fetch_add(1);
1491       if (work_item >= total_num_mult_blocks) break;
1492       const TL alpha(1.0);  // Stored in a variable so we can get a pointer
1493       const TL beta(0.0);   // Stored in a variable so we can get a pointer
1494       wrapper_libxsmm_spmdm_compute_generic_thread(
1495           empty_type_wrapper<TL>{}, &entry->handle,
1496           (transpose_left ? 'T' : 'N'), 'N', &alpha, entry->output_csr,
1497           right_data, (transpose_output ? 'T' : 'N'), &beta, output_data,
1498           work_item, i, num_threads);
1499     }
1500   });
1501   // Put handle + CSR storage back into cache
1502   cache->return_cache_entry(std::move(entry));
1503 }
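// Summary of the flow above: the left matrix is first converted to libxsmm's
// CSR slices and then multiplied against the dense right matrix, in two
// passes that are each dynamically load-balanced; every thread repeatedly
// fetch_add()s the shared block counter and processes the returned block id
// until the creation (then compute) blocks are exhausted. With alpha = 1 and
// beta = 0, each output block is computed fresh rather than accumulated into,
// following the usual GEMM alpha/beta convention.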
1504 
1505 #endif  // TENSORFLOW_USE_LIBXSMM
1506 
1507 // Here is an overview of the SparseMatMul code. Note that we assume that the
1508 // left matrix is sparse.
1509 //
1510 // The matrix "left" is divided into a grid with blocksize of (M, KL). Each
1511 // block is encoded as a SparseSlice. These grid elements are stored as
1512 // std::vector<std::vector<SparseSlice>>. Each element of the outer vector
1513 // represents M rows of the left matrix. Lets call these elements l_i and lets
1514 // call each element of the inner vector L_mk.
1515 //
1516 // The matrix "right" is divided into a grid with block size KR * NR.  Let's
1517 // denote the blocks on the right as R_kn. Note that we ensure that KL divides
1518 // KR so that for each element R_kn, we don't need to multiply it with any
1519 // partial L_mk blocks.
1520 //
1521 // We then multiply each right side block R_kn with the full "left" matrix and
1522 // update the output. These iterations are run sequentially since R_kn are
1523 // packed into the same underlying temporary buffer.
1524 //
1525 // In each iteration we do the following:
1526 // 1. Create slices r_j of R_kn: We split R_kn into vertical blocks with N
1527 //    (=128) columns and then concatenate these slices into a buffer. This is
1528 //    done so that each slice r_j of R_kn is stored contiguously in memory. Note
1529 //    that if R_kn has dimensions (KR, NR), we create NR / N slices, and the
1530 //    buffer has dimensions (KR * NR / N, N) (assuming N divides NR).
1531 // 2. For each (l_i, r_j), we compute the inner product using the GEPP function
1532 //    and update the output block o_ij. These calls are further blocked to
1533 //    reduce the working set size. In each iteration we take IB elements from
1534 //    {l_i} and JB elements from {r_j} and compute the IB * JB inner products.
1535 template <typename TL, typename TR>
1536 inline void SparseMatMul<TL, TR>::Compute(
1537     typename SparseMatMul<TL, TR>::TensorInfoCache* /*cache*/,
1538     const typename SparseMatMul<TL, TR>::ConstMatrixMapL& left,
1539     const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right,
1540     bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
1541     bool transpose_output, MatrixMap* output) {
1542   const int num_threads = thread_pool->num_threads;
1543   int KR, NR, KL, JB, IB;
1544   ComputeBlockSizes(left, right, transpose_left, num_threads, &KR, &NR, &KL,
1545                     &JB, &IB);
1546   // Slice the left matrix
1547   std::vector<std::vector<SparseSlice<TL>*>> left_slices;
1548   std::unique_ptr<BlockingCounter> sparse_slice_counter =
1549       CreateSparseSlices(ConstMatrixMapL(left.data(), left.dimensions()),
1550                          transpose_left, M, K, KL, &left_slices, thread_pool);
1551   const int num_left_slices = left_slices.size();
1552 
1553   const int right_dim0 = right.dimension(0);
1554   const int right_dim1 = right.dimension(1);
1555   // Allocate buffer for storing slices of right matrix.
1556   // Allocate a buffer for storing slices of the right matrix.
1557   // Note that the buffer needs enough space to hold at most a KR * NR matrix,
1558   // since that is the block size per iteration.
1559       std::min(KR, right_dim0) * ((std::min(NR, right_dim1) + N - 1) / N);
1560   MatrixR buffer(buffer_num_rows, N);
1561   std::vector<ConstMatrixMapR*> right_slices;
1562 
1563   std::vector<SparseSlice<TL>*> block_left_slices;
1564   std::vector<std::function<void(void)>> tasks;
1565   // Number of blocks based on block sizes of KR * NR.
1566   const int num_k_blocks = (right_dim0 + KR - 1) / KR;
1567   const int num_n_blocks = (right_dim1 + NR - 1) / NR;
1568   std::unique_ptr<BlockingCounter> dense_slice_counter;
1569 
1570   for (int nb = 0; nb < num_n_blocks; ++nb) {
1571     const int right_num_cols =
1572         std::min(NR, static_cast<int>(right_dim1 - NR * nb));
1573     for (int kb = 0; kb < num_k_blocks; ++kb) {
1574       const int right_num_rows =
1575           std::min(KR, static_cast<int>(right_dim0 - KR * kb));
1576       dense_slice_counter = CreateDenseSlices(
1577           right, kb * KR, right_num_rows, nb * NR, right_num_cols, thread_pool,
1578           &buffer, &right_slices);
1579       const int num_right_slices = right_slices.size();
1580       tasks.reserve(num_left_slices * num_right_slices);
1581       for (int j_outer = 0; j_outer < num_right_slices; j_outer += JB) {
1582         for (int i_outer = 0; i_outer < num_left_slices; i_outer += IB) {
1583           for (int j_inner = j_outer;
1584                j_inner < std::min(num_right_slices, j_outer + JB); ++j_inner) {
1585             const int num_cols = std::min(N, right_num_cols - N * j_inner);
1586             for (int i_inner = i_outer;
1587                  i_inner < std::min(num_left_slices, i_outer + IB); ++i_inner) {
1588               block_left_slices.clear();
1589               int begin = kb * KR / KL;
1590               int end = std::min<int>((kb + 1) * KR / KL,
1591                                       (right.dimension(0) + KL - 1) / KL);
1592               DCHECK_LT(begin, end);
1593               block_left_slices.insert(block_left_slices.begin(),
1594                                        left_slices[i_inner].begin() + begin,
1595                                        left_slices[i_inner].begin() + end);
1596               tasks.push_back(std::bind(
1597                   &ComputeOutputBlock, block_left_slices,
1598                   std::ref(*right_slices[j_inner]), num_cols, M * i_inner,
1599                   N * j_inner + nb * NR, kb == 0, transpose_output, output));
1600             }
1601           }
1602         }
1603       }
1604       if (sparse_slice_counter) {
1605         sparse_slice_counter->Wait();
1606         sparse_slice_counter.reset(nullptr);
1607       }
1608       if (dense_slice_counter) {
1609         dense_slice_counter->Wait();
1610         dense_slice_counter.reset(nullptr);
1611       }
1612       BlockingCounter bc(tasks.size());
1613       for (const auto& t : tasks) {
1614         thread_pool->workers->Schedule([&bc, &t]() {
1615           t();
1616           bc.DecrementCount();
1617         });
1618       }
1619       bc.Wait();
1620       tasks.clear();
1621       for (auto& temp : right_slices) {
1622         delete temp;
1623       }
1624       right_slices.clear();
1625     }
1626   }
1627   for (auto& left_slice : left_slices) {
1628     for (auto& temp : left_slice) {
1629       delete temp;
1630     }
1631     left_slice.clear();
1632   }
1633 }
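// Worked example of the blocking above (shapes hypothetical): multiplying a
// 1024 x 2048 sparse left matrix by a 2048 x 2048 dense right matrix with 16
// threads gives KR = 2048, NR = 512, KL = 1024, JB = 2, IB = 16 (see the
// example after ComputeBlockSizes), hence num_k_blocks = 1 and
// num_n_blocks = 4. Each of the 4 iterations packs one 2048 x 512 right block
// into 4 slices of shape (2048, 128); with 1024 / M = 16 row groups l_i on
// the left, that yields 16 * 4 = 64 ComputeOutputBlock tasks per iteration,
// enumerated in tiles of IB x JB = 16 x 2 (left, right) slice pairs to keep
// the working set small.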
1634 
1635 #define REGISTER_SPARSE_MATMUL(TA, TB)                   \
1636   REGISTER_KERNEL_BUILDER(Name("SparseMatMul")           \
1637                               .Device(DEVICE_CPU)        \
1638                               .TypeConstraint<TA>("Ta")  \
1639                               .TypeConstraint<TB>("Tb"), \
1640                           SparseMatMulOp<TA, TB, SparseMatMul>);
1641 #ifdef TENSORFLOW_USE_LIBXSMM
1642 #define REGISTER_SPARSE_MATMUL_LIBXSMM(TA, TB)           \
1643   REGISTER_KERNEL_BUILDER(Name("SparseMatMul")           \
1644                               .Device(DEVICE_CPU)        \
1645                               .TypeConstraint<TA>("Ta")  \
1646                               .TypeConstraint<TB>("Tb"), \
1647                           SparseMatMulOp<TA, TB, LibxsmmSparseMatMul>);
1648 #endif
1649 
1650 REGISTER_SPARSE_MATMUL(float, bfloat16);
1651 REGISTER_SPARSE_MATMUL(bfloat16, float);
1652 
1653 #ifdef TENSORFLOW_USE_LIBXSMM
1654 REGISTER_SPARSE_MATMUL_LIBXSMM(bfloat16, bfloat16);
1655 REGISTER_SPARSE_MATMUL_LIBXSMM(float, float);
1656 #else
1657 REGISTER_SPARSE_MATMUL(bfloat16, bfloat16);
1658 REGISTER_SPARSE_MATMUL(float, float);
1659 #endif
1660 
1661 #undef REGISTER_SPARSE_MATMUL
1662 
1663 }  // end namespace tensorflow
1664