1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // See docs in ../ops/math_ops.cc.
17
18 #define EIGEN_USE_THREADS
19
20 #include "tensorflow/core/kernels/sparse_matmul_op.h"
21
22 #include <map>
23 #include <memory>
24 #include <vector>
25
26 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
27 #include "tensorflow/core/common_runtime/device.h"
28 #include "tensorflow/core/framework/bfloat16.h"
29 #include "tensorflow/core/framework/op.h"
30 #include "tensorflow/core/framework/op_kernel.h"
31 #include "tensorflow/core/framework/types.h"
32 #include "tensorflow/core/kernels/fill_functor.h"
33 #include "tensorflow/core/lib/core/blocking_counter.h"
34 #include "tensorflow/core/lib/core/threadpool.h"
35 #include "tensorflow/core/platform/errors.h"
36 #include "tensorflow/core/platform/logging.h"
37 #include "tensorflow/core/platform/macros.h"
38 #include "tensorflow/core/platform/mutex.h"
39 #include "tensorflow/core/platform/thread_annotations.h"
40 #include "tensorflow/core/platform/types.h"
41 #ifdef TENSORFLOW_USE_LIBXSMM
42 #include "include/libxsmm_intrinsics_x86.h"
43 #include "include/libxsmm_malloc.h"
44 #include "include/libxsmm_spmdm.h"
45 #endif
46
47 #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
48 #include "tensorflow/core/kernels/eigen_contraction_kernel.h"
49 #endif
50
51 #define ALWAYS_INLINE EIGEN_ALWAYS_INLINE
52
53 namespace tensorflow {
54 namespace {
55
56 template <typename T>
57 using BasicMatrix = Eigen::Tensor<T, 2, Eigen::RowMajor>;
58
59 template <typename T>
60 using BasicMatrixMap =
61 Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Aligned>;
62
63 using Matrix = BasicMatrix<float>;
64 using MatrixMap = BasicMatrixMap<float>;
65 using CPUDevice = Eigen::ThreadPoolDevice;
66 using DSizes = Eigen::DSizes<Eigen::DenseIndex, 2>;
67
68 // Two commonly used static dsizes. We use Eigen::type2index to allow as much
69 // compile time optimization as possible.
70 inline Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<0>>
71 dsizes_00() {
72 return Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<0>>();
73 }
74 inline Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<0>>
75 dsizes_10() {
76 return Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<0>>();
77 }
78
79 // Blocksizes
80 // TODO(agarwal): compute these sizes based on cache sizes.
81 const int K = 64;
82 const int M = 64;
83 const int N = 128;
84
85 // This stores a sparse representation of a slice of a matrix with size
86 // (num_rows, num_cols). The slice is represented as a series of blocks of size
87 // (num_rows, b), where b = block_size for all but the last block, which may
88 // have fewer columns.
89 //
90 // num_rows and block_size are assumed to be <= 256. This allows storing
91 // different indices as uint8.
92 //
93 // For each block, we store all the non zero entries in data/data3 vector and
94 // the corresponding coordinates of the element in index/index3 vectors. index3
95 // vector stores index of 3 elements in the same row so that these elements can
96 // share the same row coordinate. Each entry in Index3 corresponds to 3 entries
97 // in data3.
98 //
99 // Note that all the data/indices of all the blocks are stored in the same
100 // vectors respectively. To identify block boundaries, we store the block
101 // offsets using index3_offset/index_offset. If there are n blocks in the slice,
102 // index3_offset and index_offset have n entries. The indices for the ith block
103 // are the values in the following range:
104 // [index3[index3_offset[i-1]], index3[index3_offset[i]]). Similarly for
105 // index_offset.
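//
// Example (for illustration): a slice with num_rows = 2, num_cols = 4 and
// block_size = 4 (a single block) whose rows are {1, 0, 2, 3} and
// {0, 5, 0, 7} is encoded as
//   index3 = {m: 0, k1: 0, k2: 2, k3: 3}   data3 = {1, 2, 3}
//   index  = {m: 1, k: 3}, {m: 1, k: 1}    data  = {7, 5}
//   index3_offset = {1}                    index_offset = {2}
// Note that a leftover pair in a row is pushed in reverse order (see
// Initialize below).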
106 template <typename T>
107 struct SparseSlice {
108 using ConstMatrixMap = BasicMatrixMap<const T>;
109
110 public:
111 // Indices of three elements on the same row.
112 struct Index3 {
113 uint8 m; // row
114 // columns
115 uint8 k1;
116 uint8 k2;
117 uint8 k3;
118 };
119
120 // Index of one element.
121 struct Index {
122 uint8 m;
123 uint8 k;
124 };
125
126   SparseSlice(int nrows, int ncols, int bsize)
127 : num_rows(nrows), num_cols(ncols), block_size(bsize) {
128 DCHECK_LE(nrows, 256);
129 DCHECK_LE(block_size, 256);
130 }
131
132 // Initializes the slice with data starting at mat(0, col_offset) and with
133 // size (num_rows, num_cols).
134 // If Transpose is true, implicitly transposes mat.
135 template <bool Transpose = false>
136 void Initialize(const ConstMatrixMap& mat, int col_offset);
137
138 void Clear();
139
140 // See comments above.
141 std::vector<int> index3_offset;
142 std::vector<Index3> index3;
143 std::vector<T> data3;
144
145 // See comments above. Similar to "index3" except that each element in "index"
146 // corresponds to one element in data.
147 std::vector<int> index_offset;
148 std::vector<Index> index;
149 std::vector<T> data;
150
151 // Number of rows and columns for the slice.
152 const int num_rows;
153 const int num_cols;
154
155 // Block size used to initialize from a matrix.
156 const int block_size;
157 };
158
159 template <typename T>
160 bool IsZero(T v);
161
162 template <>
163 ALWAYS_INLINE bool IsZero(bfloat16 v) {
164 return !static_cast<bool>(v);
165 }
166
167 template <>
168 ALWAYS_INLINE bool IsZero(float v) {
169 return v == 0.0f;
170 }
171
172 template <typename T>
173 template <bool Transpose>
174 void SparseSlice<T>::Initialize(
175 const typename SparseSlice<T>::ConstMatrixMap& mat, int col_offset) {
176 const int mat_rows = Transpose ? mat.dimension(1) : mat.dimension(0);
177 const int mat_cols = Transpose ? mat.dimension(0) : mat.dimension(1);
178 DCHECK_LE(num_rows, mat_rows);
179 DCHECK_LE(num_cols + col_offset, mat_cols);
180
181 int num_blocks = (num_cols + block_size - 1) / block_size;
182 int mat_size = num_rows * num_cols;
183
184 index3_offset.reserve(num_blocks);
185 data3.reserve(mat_size);
186 index3.reserve(mat_size / 3);
187
188 index_offset.reserve(num_blocks);
189 data.reserve(num_blocks * num_rows * 2);
190 index.reserve(num_blocks * num_rows * 2);
191
192 Index3 idx3;
193 const int stride = Transpose ? mat.dimension(1) : 1;
194
195 for (int i = 0; i < num_blocks; ++i) {
196 int num_block_cols = std::min(block_size, num_cols - block_size * i);
197 for (int row = 0; row < num_rows; ++row) {
198 idx3.m = static_cast<uint8>(row);
199 // Safety note: The following code has a race, since it checks whether
200 // *curr is nonzero and then reads it again on use. However, the result
201 // of the race is only that some of the "nonzeros" in the resulting sparse
202 // representation may actually be zero, which is harmless.
203 const auto* start =
204 Transpose ? &mat(col_offset, row) : &mat(row, col_offset);
205 const auto* curr = start;
206 const auto* end = start + stride * num_block_cols;
207 uint8 k = 0;
208 #define NEXT_ELEM \
209 curr += stride; \
210 ++k;
211 #define EAT_ZEROS \
212 while (curr < end && IsZero<T>(*curr)) { \
213 NEXT_ELEM; \
214 }
215 while (true) {
216 EAT_ZEROS
217 if (curr >= end) break;
218 idx3.k1 = k;
219 const T value1 = *curr;
220 NEXT_ELEM;
221
222 EAT_ZEROS
223 if (curr >= end) {
224 data.push_back(value1);
225 index.push_back({idx3.m, idx3.k1});
226 break;
227 }
228 idx3.k2 = k;
229 const T value2 = *curr;
230 NEXT_ELEM;
231
232 EAT_ZEROS
233 if (curr >= end) {
234 data.push_back(value2);
235 index.push_back({idx3.m, idx3.k2});
236 data.push_back(value1);
237 index.push_back({idx3.m, idx3.k1});
238 break;
239 }
240 idx3.k3 = k;
241 data3.push_back(value1);
242 data3.push_back(value2);
243 data3.push_back(*curr);
244 NEXT_ELEM;
245 index3.push_back(idx3);
246 #undef NEXT_ELEM
247 #undef EAT_ZEROS
248 }
249 }
250 col_offset += block_size;
251 index3_offset.push_back(index3.size());
252 index_offset.push_back(index.size());
253 }
254 DCHECK_EQ(index3_offset.size(), num_blocks);
255 DCHECK_EQ(index_offset.size(), num_blocks);
256 DCHECK_EQ(3 * index3.size(), data3.size());
257 DCHECK_EQ(index.size(), data.size());
258 }
259
260 template <typename T>
261 void SparseSlice<T>::Clear() {
262 index3_offset.clear();
263 index3.clear();
264 data3.clear();
265 index_offset.clear();
266 index.clear();
267 data.clear();
268 }
269
270 using Packet = Eigen::internal::packet_traits<float>::type;
271 const int kNumOperands = (sizeof(Packet) / sizeof(float));
272 #define LOAD(x) Eigen::internal::pload<Packet>(x);
273 #define EXPAND_BFLOAT_L(x, y) \
274 const auto y = Eigen::internal::pexpand_bf16_l<Packet>(x);
275 #define EXPAND_BFLOAT_U(x, y) \
276 const auto y = Eigen::internal::pexpand_bf16_u<Packet>(x);
277 #define STORE(x, y) Eigen::internal::pstore<float>(x, y);
278 #define FMA(a, b, c, d) d = Eigen::internal::pmadd<Packet>(a, b, c);
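// Note: kNumOperands is the number of floats in one Packet (e.g. 8 with AVX,
// where Packet is a 256-bit register). One packet load of bfloat16 data
// therefore covers 2 * kNumOperands values, which EXPAND_BFLOAT_L /
// EXPAND_BFLOAT_U widen into two float packets.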
279
280 ALWAYS_INLINE float ConvertBfloat16ToFloat(const bfloat16* src) {
281 float out = 0;
282 auto tmp = reinterpret_cast<bfloat16*>(&out);
283 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
284 tmp[0] = *src;
285 #else
286 tmp[1] = *src;
287 #endif
288 return out;
289 }
290
291 ALWAYS_INLINE Packet ConvertFourBfloat16ToFloat(const bfloat16* src) {
292 return Eigen::internal::pload4bf16<Packet>(
293 reinterpret_cast<const float*>(src));
294 }
295
296 ALWAYS_INLINE Packet ConvertTwoBfloat16ToFloat(const bfloat16* src) {
297 return Eigen::internal::pload2bf16<Packet>(
298 reinterpret_cast<const float*>(src));
299 }
300
301 ALWAYS_INLINE void ScalarMulAdd(const float a, const float** inp, float** out) {
302 **out += a * **inp;
303 ++*inp;
304 ++*out;
305 }
306
307 ALWAYS_INLINE void ScalarMulAdd(const float a, const bfloat16** inp,
308 float** out) {
309 float inp_f = ConvertBfloat16ToFloat(*inp);
310 **out += a * inp_f;
311 ++*inp;
312 ++*out;
313 }
314 ALWAYS_INLINE void ScalarMulAdd3Way(const float a1, const float a2,
315 const float a3, const bfloat16** inp1,
316 const bfloat16** inp2,
317 const bfloat16** inp3, float** out) {
318 float inp1_f = ConvertBfloat16ToFloat(*inp1);
319 float inp2_f = ConvertBfloat16ToFloat(*inp2);
320 float inp3_f = ConvertBfloat16ToFloat(*inp3);
321 **out += a1 * inp1_f + a2 * inp2_f + a3 * inp3_f;
322 ++*out;
323 ++*inp1;
324 ++*inp2;
325 ++*inp3;
326 }
327
328 ALWAYS_INLINE void ScalarMulAdd3Way(const float a1, const float a2,
329 const float a3, const float** inp1,
330 const float** inp2, const float** inp3,
331 float** out) {
332 **out += a1 * **inp1 + a2 * **inp2 + a3 * **inp3;
333 ++*out;
334 ++*inp1;
335 ++*inp2;
336 ++*inp3;
337 }
338
339 ALWAYS_INLINE void LoadSingleScalar(const bfloat16** data, Packet* l) {
340 auto tmp = ConvertBfloat16ToFloat(*data);
341 *l = Eigen::internal::pset1<Packet>(tmp);
342 ++*data;
343 }
344
345 ALWAYS_INLINE void LoadTwoScalars(const bfloat16** data, Packet* l1,
346 Packet* l2) {
347 if (kNumOperands >= 2) {
348 auto tmp = ConvertTwoBfloat16ToFloat(*data);
349 *l1 = Eigen::internal::pbroadcast_first<Packet>(tmp);
350 *l2 = Eigen::internal::pbroadcast_second<Packet>(tmp);
351 *data += 2;
352 } else {
353 LoadSingleScalar(data, l1);
354 LoadSingleScalar(data, l2);
355 }
356 }
357
358 ALWAYS_INLINE void LoadFourScalars(const bfloat16** data, Packet* l1,
359 Packet* l2, Packet* l3, Packet* l4) {
360 if (kNumOperands >= 4) {
361 auto tmp = ConvertFourBfloat16ToFloat(*data);
362 *l1 = Eigen::internal::pbroadcast_first<Packet>(tmp);
363 *l2 = Eigen::internal::pbroadcast_second<Packet>(tmp);
364 *l3 = Eigen::internal::pbroadcast_third<Packet>(tmp);
365 *l4 = Eigen::internal::pbroadcast_fourth<Packet>(tmp);
366 *data += 4;
367 } else {
368 LoadTwoScalars(data, l1, l2);
369 LoadTwoScalars(data, l3, l4);
370 }
371 }
372
373 ALWAYS_INLINE void LoadSingleScalar(const float** data, Packet* l) {
374 *l = Eigen::internal::pload1<Packet>(*data);
375 ++(*data);
376 }
377
378 ALWAYS_INLINE void LoadTwoScalars(const float** data, Packet* l1, Packet* l2) {
379 LoadSingleScalar(data, l1);
380 LoadSingleScalar(data, l2);
381 }
382
383 ALWAYS_INLINE void LoadFourScalars(const float** data, Packet* l1, Packet* l2,
384 Packet* l3, Packet* l4) {
385 LoadTwoScalars(data, l1, l2);
386 LoadTwoScalars(data, l3, l4);
387 }
388
389 template <typename T>
390 ALWAYS_INLINE void LoadThreeScalars(const T** data, Packet* l1, Packet* l2,
391 Packet* l3) {
392 LoadTwoScalars(data, l1, l2);
393 LoadSingleScalar(data, l3);
394 }
395
396 template <typename T>
397 ALWAYS_INLINE void LoadSixScalars(const T** data, Packet* l1, Packet* l2,
398 Packet* l3, Packet* l4, Packet* l5,
399 Packet* l6) {
400 LoadFourScalars(data, l1, l2, l3, l4);
401 LoadTwoScalars(data, l5, l6);
402 }
403
404 // Vectorized version of ScalarMulAdd.
405 ALWAYS_INLINE void MulAdd(const Packet a, const bfloat16** binp, float** out) {
406 auto inp = reinterpret_cast<const float*>(*binp);
407 const auto b = LOAD(inp);
408 EXPAND_BFLOAT_L(b, b_0);
409 EXPAND_BFLOAT_U(b, b_1);
410 *binp += 2 * kNumOperands;
411 auto c1 = LOAD(*out);
412 auto c2 = LOAD(*out + kNumOperands);
413 FMA(a, b_0, c1, c1);
414 FMA(a, b_1, c2, c2);
415 STORE(*out, c1);
416 STORE(*out + kNumOperands, c2);
417 *out += 2 * kNumOperands;
418 }
419
420 // Vectorized version of ScalarMulAdd3Way.
421 ALWAYS_INLINE void MulAdd3Way(const Packet a1, const Packet a2, const Packet a3,
422 const bfloat16** binp1, const bfloat16** binp2,
423 const bfloat16** binp3, float** out) {
424 auto inp1 = reinterpret_cast<const float*>(*binp1);
425 auto inp2 = reinterpret_cast<const float*>(*binp2);
426 auto inp3 = reinterpret_cast<const float*>(*binp3);
427 auto c1 = LOAD(*out);
428 auto c2 = LOAD(*out + kNumOperands);
429 const auto b1 = LOAD(inp1);
430 EXPAND_BFLOAT_L(b1, b1_0);
431 EXPAND_BFLOAT_U(b1, b1_1);
432 *binp1 += 2 * kNumOperands;
433 const auto b2 = LOAD(inp2);
434 EXPAND_BFLOAT_L(b2, b2_0);
435 EXPAND_BFLOAT_U(b2, b2_1);
436 *binp2 += 2 * kNumOperands;
437 const auto b3 = LOAD(inp3);
438 EXPAND_BFLOAT_L(b3, b3_0);
439 EXPAND_BFLOAT_U(b3, b3_1);
440 *binp3 += 2 * kNumOperands;
441 FMA(a1, b1_0, c1, c1);
442 FMA(a1, b1_1, c2, c2);
443 FMA(a2, b2_0, c1, c1);
444 FMA(a2, b2_1, c2, c2);
445 FMA(a3, b3_0, c1, c1);
446 FMA(a3, b3_1, c2, c2);
447 STORE(*out, c1);
448 STORE(*out + kNumOperands, c2);
449 *out += 2 * kNumOperands;
450 }
451
452 // Unroll MulAdd3Way for two iterations
453 ALWAYS_INLINE void TwoMulAdd3Way(const Packet a1, const Packet a2,
454 const Packet a3, const bfloat16** binp1,
455 const bfloat16** binp2, const bfloat16** binp3,
456 float** out) {
457 auto inp1 = reinterpret_cast<const float*>(*binp1);
458 auto inp2 = reinterpret_cast<const float*>(*binp2);
459 auto inp3 = reinterpret_cast<const float*>(*binp3);
460 auto c1 = LOAD(*out);
461 auto c2 = LOAD(*out + kNumOperands);
462 const auto b1 = LOAD(inp1);
463 const auto b2 = LOAD(inp2);
464 const auto b3 = LOAD(inp3);
465
466 EXPAND_BFLOAT_L(b1, b1_0);
467 EXPAND_BFLOAT_U(b1, b1_1);
468 EXPAND_BFLOAT_L(b2, b2_0);
469 EXPAND_BFLOAT_U(b2, b2_1);
470 EXPAND_BFLOAT_L(b3, b3_0);
471 EXPAND_BFLOAT_U(b3, b3_1);
472 auto c3 = LOAD(*out + 2 * kNumOperands);
473 auto c4 = LOAD(*out + 3 * kNumOperands);
474 const auto b4 = LOAD(inp1 + kNumOperands);
475 const auto b5 = LOAD(inp2 + kNumOperands);
476 const auto b6 = LOAD(inp3 + kNumOperands);
477
478 EXPAND_BFLOAT_L(b4, b4_0);
479 EXPAND_BFLOAT_U(b4, b4_1);
480 EXPAND_BFLOAT_L(b5, b5_0);
481 EXPAND_BFLOAT_U(b5, b5_1);
482 EXPAND_BFLOAT_L(b6, b6_0);
483 EXPAND_BFLOAT_U(b6, b6_1);
484
485 FMA(a1, b1_0, c1, c1);
486 FMA(a1, b1_1, c2, c2);
487 FMA(a1, b4_0, c3, c3);
488 FMA(a1, b4_1, c4, c4);
489 FMA(a2, b2_0, c1, c1);
490 FMA(a2, b2_1, c2, c2);
491 FMA(a2, b5_0, c3, c3);
492 FMA(a2, b5_1, c4, c4);
493 FMA(a3, b3_0, c1, c1);
494 FMA(a3, b3_1, c2, c2);
495 FMA(a3, b6_0, c3, c3);
496 FMA(a3, b6_1, c4, c4);
497 STORE(*out, c1);
498 STORE(*out + kNumOperands, c2);
499 STORE(*out + 2 * kNumOperands, c3);
500 STORE(*out + 3 * kNumOperands, c4);
501 *out += 4 * kNumOperands;
502 *binp1 += 4 * kNumOperands;
503 *binp2 += 4 * kNumOperands;
504 *binp3 += 4 * kNumOperands;
505 }
506
507 // Apply MulAdd3Way on 128 operands.
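// Each TwoMulAdd3Way call consumes 4 * kNumOperands bfloat16 values from every
// input pointer and writes 4 * kNumOperands floats, so the two calls per
// iteration below cover 8 * kNumOperands columns and the loop covers all 128.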
508 ALWAYS_INLINE void MulAdd3Way128(const Packet a1, const Packet a2,
509 const Packet a3, const bfloat16** inp1,
510 const bfloat16** inp2, const bfloat16** inp3,
511 float** out) {
512 for (int k = 0; k < 128 / (8 * kNumOperands); ++k) {
513 TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
514 TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
515 }
516 }
517
518 // Vectorized version of ScalarMulAdd
519 ALWAYS_INLINE void MulAdd(const Packet a, const float** inp, float** out) {
520 const auto b = LOAD(*inp);
521 *inp += kNumOperands;
522 auto c = LOAD(*out);
523 FMA(a, b, c, c);
524 STORE(*out, c);
525 *out += kNumOperands;
526 }
527
528 // Vectorized version of ScalarMulAdd3Way
529 ALWAYS_INLINE void MulAdd3Way(const Packet a1, const Packet a2, const Packet a3,
530 const float** inp1, const float** inp2,
531 const float** inp3, float** out) {
532 auto c = LOAD(*out);
533 const auto b1 = LOAD(*inp1);
534 *inp1 += kNumOperands;
535 const auto b2 = LOAD(*inp2);
536 *inp2 += kNumOperands;
537 const auto b3 = LOAD(*inp3);
538 *inp3 += kNumOperands;
539 FMA(a1, b1, c, c);
540 FMA(a2, b2, c, c);
541 FMA(a3, b3, c, c);
542 STORE(*out, c);
543 *out += kNumOperands;
544 }
545
546 // Unroll MulAdd3Way for two iterations
547 ALWAYS_INLINE void TwoMulAdd3Way(const Packet a1, const Packet a2,
548 const Packet a3, const float** inp1,
549 const float** inp2, const float** inp3,
550 float** out) {
551 auto c1 = LOAD(*out);
552 const auto b1 = LOAD(*inp1);
553 const auto b2 = LOAD(*inp2);
554 const auto b3 = LOAD(*inp3);
555
556 auto c2 = LOAD(*out + kNumOperands);
557 const auto b4 = LOAD(*inp1 + kNumOperands);
558 const auto b5 = LOAD(*inp2 + kNumOperands);
559 const auto b6 = LOAD(*inp3 + kNumOperands);
560
561 FMA(a1, b1, c1, c1);
562 FMA(a1, b4, c2, c2);
563 FMA(a2, b2, c1, c1);
564 FMA(a2, b5, c2, c2);
565 FMA(a3, b3, c1, c1);
566 FMA(a3, b6, c2, c2);
567 STORE(*out, c1);
568 STORE(*out + kNumOperands, c2);
569 *out += 2 * kNumOperands;
570 *inp1 += 2 * kNumOperands;
571 *inp2 += 2 * kNumOperands;
572 *inp3 += 2 * kNumOperands;
573 }
574
575 // Unroll MulAdd3Way for four iterations
576 ALWAYS_INLINE void FourMulAdd3Way(const Packet a1, const Packet a2,
577 const Packet a3, const float** inp1,
578 const float** inp2, const float** inp3,
579 float** out) {
580 TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
581 TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
582 }
583
584 // Apply MulAdd3Way on 128 operands.
585 ALWAYS_INLINE void MulAdd3Way128(const Packet a1, const Packet a2,
586 const Packet a3, const float** inp1,
587 const float** inp2, const float** inp3,
588 float** out) {
589 if (kNumOperands == 8) {
590 FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
591 FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
592 FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
593 FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
594 } else {
595 DCHECK_LE(4 * kNumOperands, 128);
596 for (int i = 0; i < 128 / (4 * kNumOperands); ++i) {
597 MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
598 MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
599 MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
600 MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
601 }
602 }
603 }
604 // Computes product of "left_slices" with "num_cols" columns of "right", and
605 // stores the output in *"output".
606 // Note that left_slices is a list of SparseSlices, which are conceptually
607 // assumed to be concatenated along the column dimension. Also each SparseSlice
608 // is encoded as a list of blocks with up to N columns. See SparseSlice for more
609 // details.
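// For reference, the function below precomputes output row pointers for M
// (= 64) rows and right-matrix row pointers for K (= 64) rows per block,
// advancing k_offset by left.block_size after each block of a slice.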
610 template <typename TL, typename TR, int Cols>
611 inline void GEPP(
612 const std::vector<SparseSlice<TL>*>& left_slices,
613 const Eigen::TensorMap<Eigen::Tensor<const TR, 2, Eigen::RowMajor>,
614 Eigen::Aligned>& right,
615 const int num_cols, Matrix* output) {
616 const int cols = (Cols == -1) ? num_cols : Cols;
617 DCHECK_EQ(num_cols, cols);
618 const int right_num_cols = right.dimension(1);
619 const int output_num_cols = output->dimension(1);
620 static const int kNumOperandsR = kNumOperands * sizeof(float) / sizeof(TR);
621 const int cols_mod = cols % kNumOperandsR;
622 int k_offset = 0;
623 // Pre-compute pointers for output matrix.
624 float* out_ptrs[M];
625 float* const out_start = &(*output)(0, 0);
626 for (int j = 0; j < M; ++j) {
627 out_ptrs[j] = out_start + output_num_cols * j;
628 }
629 for (const auto* left_slice : left_slices) {
630 const auto& left = *left_slice;
631 const auto* data3 = (!left.data3.empty()) ? &left.data3[0] : nullptr;
632 const auto* data = (!left.data.empty()) ? &left.data[0] : nullptr;
633 const int num_blocks = left.index3_offset.size();
634 int begin3 = 0;
635 int begin = 0;
636 for (int i = 0; i < num_blocks; ++i) {
637 // Pre-compute pointers for right matrix
638 const TR* right_ptrs[K];
639 const auto* const right_start = &right(k_offset, 0);
640 DCHECK_LT(k_offset, right.dimension(0));
641 for (int j = 0; j < K; ++j) {
642 right_ptrs[j] = right_start + right_num_cols * j;
643 }
644
645 const int end3 = left.index3_offset[i];
646 int j = begin3;
647 // Loop unrolled for 2 iterations.
648 for (; j + 1 < end3; j += 2) {
649 Packet l1, l2, l3, nl1, nl2, nl3;
650 LoadSixScalars(&data3, &l1, &l2, &l3, &nl1, &nl2, &nl3);
651 const auto& index = left.index3[j];
652 const auto& nindex = left.index3[j + 1];
653 float* out = out_ptrs[index.m];
654 float* nout = out_ptrs[nindex.m];
655 const auto* r1 = right_ptrs[index.k1];
656 const auto* r2 = right_ptrs[index.k2];
657 const auto* r3 = right_ptrs[index.k3];
658
659 const auto* nr1 = right_ptrs[nindex.k1];
660 const auto* nr2 = right_ptrs[nindex.k2];
661 const auto* nr3 = right_ptrs[nindex.k3];
662 if (cols == 128) {
663 MulAdd3Way128(l1, l2, l3, &r1, &r2, &r3, &out);
664 MulAdd3Way128(nl1, nl2, nl3, &nr1, &nr2, &nr3, &nout);
665 } else {
666 for (int n = 0; n < cols / kNumOperandsR; ++n) {
667 MulAdd3Way(l1, l2, l3, &r1, &r2, &r3, &out);
668 MulAdd3Way(nl1, nl2, nl3, &nr1, &nr2, &nr3, &nout);
669 }
670
671 const float sl1 = Eigen::internal::pfirst<Packet>(l1);
672 const float sl2 = Eigen::internal::pfirst<Packet>(l2);
673 const float sl3 = Eigen::internal::pfirst<Packet>(l3);
674 const float nsl1 = Eigen::internal::pfirst<Packet>(nl1);
675 const float nsl2 = Eigen::internal::pfirst<Packet>(nl2);
676 const float nsl3 = Eigen::internal::pfirst<Packet>(nl3);
677 for (int k = 0; k < cols_mod; ++k) {
678 ScalarMulAdd3Way(sl1, sl2, sl3, &r1, &r2, &r3, &out);
679 ScalarMulAdd3Way(nsl1, nsl2, nsl3, &nr1, &nr2, &nr3, &nout);
680 }
681 }
682 }
683 if (j < end3) {
684 Packet l1, l2, l3;
685 LoadThreeScalars(&data3, &l1, &l2, &l3);
686
687 const auto& index = left.index3[j];
688 float* out = out_ptrs[index.m];
689 const auto* r1 = right_ptrs[index.k1];
690 const auto* r2 = right_ptrs[index.k2];
691 const auto* r3 = right_ptrs[index.k3];
692 if (cols == 128) {
693 MulAdd3Way128(l1, l2, l3, &r1, &r2, &r3, &out);
694 } else {
695 for (int n = 0; n < cols / kNumOperandsR; ++n) {
696 MulAdd3Way(l1, l2, l3, &r1, &r2, &r3, &out);
697 }
698 const float sl1 = Eigen::internal::pfirst<Packet>(l1);
699 const float sl2 = Eigen::internal::pfirst<Packet>(l2);
700 const float sl3 = Eigen::internal::pfirst<Packet>(l3);
701 for (int k = 0; k < cols_mod; ++k) {
702 ScalarMulAdd3Way(sl1, sl2, sl3, &r1, &r2, &r3, &out);
703 }
704 }
705 }
706 begin3 = end3;
707 int end = left.index_offset[i];
708 // Loop unrolled for 4 iterations.
709 j = begin;
710 for (; j + 3 < end; j += 4) {
711 Packet l, nl, n2l, n3l;
712 LoadFourScalars(&data, &l, &nl, &n2l, &n3l);
713
714 const auto& index = left.index[j];
715 const auto& nindex = left.index[j + 1];
716 const auto& n2index = left.index[j + 2];
717 const auto& n3index = left.index[j + 3];
718 const auto* r = right_ptrs[index.k];
719 const auto* nr = right_ptrs[nindex.k];
720 const auto* n2r = right_ptrs[n2index.k];
721 const auto* n3r = right_ptrs[n3index.k];
722 float* out = out_ptrs[index.m];
723 float* nout = out_ptrs[nindex.m];
724 float* n2out = out_ptrs[n2index.m];
725 float* n3out = out_ptrs[n3index.m];
726
727 for (int n = 0; n < cols / kNumOperandsR; ++n) {
728 MulAdd(l, &r, &out);
729 MulAdd(nl, &nr, &nout);
730 MulAdd(n2l, &n2r, &n2out);
731 MulAdd(n3l, &n3r, &n3out);
732 }
733
734 const float sl1 = Eigen::internal::pfirst<Packet>(l);
735 const float sl2 = Eigen::internal::pfirst<Packet>(nl);
736 const float sl3 = Eigen::internal::pfirst<Packet>(n2l);
737 const float sl4 = Eigen::internal::pfirst<Packet>(n3l);
738 for (int k = 0; k < cols_mod; ++k) {
739 ScalarMulAdd(sl1, &r, &out);
740 ScalarMulAdd(sl2, &nr, &nout);
741 ScalarMulAdd(sl3, &n2r, &n2out);
742 ScalarMulAdd(sl4, &n3r, &n3out);
743 }
744 }
745 while (j < end) {
746 Packet l;
747 LoadSingleScalar(&data, &l);
748 const auto& index = left.index[j];
749 const auto* r = right_ptrs[index.k];
750 float* out = out_ptrs[index.m];
751 for (int n = 0; n < cols / kNumOperandsR; ++n) {
752 MulAdd(l, &r, &out);
753 }
754 const float sl = Eigen::internal::pfirst<Packet>(l);
755 for (int k = 0; k < cols_mod; ++k) {
756 ScalarMulAdd(sl, &r, &out);
757 }
758 j++;
759 }
760 k_offset += left.block_size;
761 begin = end;
762 }
763 }
764 }
765
766 #undef LOAD
767 #undef EXPAND_BFLOAT_L
768 #undef EXPAND_BFLOAT_U
769 #undef STORE
770 #undef FMA
771
772 } // namespace
773
774 template <typename TL, typename TR>
775 class SparseMatMul {
776 using MatrixL = BasicMatrix<TL>;
777 using MatrixR = BasicMatrix<TR>;
778 using ConstMatrixMapL = BasicMatrixMap<const TL>;
779 using ConstMatrixMapR = BasicMatrixMap<const TR>;
780 using MatrixMapR = BasicMatrixMap<TR>;
781
782 public:
783 // Not used; added to match interface of LibxsmmSparseMatMul
784 struct TensorInfoCache {};
785
786 // Perform matrix multiplication of "left" and "right", and store the result
787 // in *"output".
788 public:
789 static inline void Compute(TensorInfoCache* cache,
790 const ConstMatrixMapL& left,
791 const ConstMatrixMapR& right, bool transpose_left,
792 const DeviceBase::CpuWorkerThreads* thread_pool,
793 bool transpose_output, MatrixMap* output);
794
795 private:
796 // Computes multiplication of left and num_cols columns of right, and stores
797 // the output block in *"output" at offsets "output_row_offset" and
798 // "output_col_offset". If assign is true, assigns the value to that block,
799 // else adds the values to the existing values.
800 static inline void ComputeOutputBlock(
801 const std::vector<SparseSlice<TL>*>& left, const ConstMatrixMapR& right,
802 int num_cols, int output_row_offset, int output_col_offset, bool assign,
803 bool transpose_output, MatrixMap* output);
804
805 // Encodes "mat" using a sparse representation and stores that in
806 // "mat_slices". "mat" is broken into a grid with sizes "slice_num_rows" and
807 // "slice_num_cols", each grid element is converted into a SparseSlice and
808 // stored in mat_slices. "slice_block_size" is used to perform further column
809 // blocking of each slice.
810 static inline std::unique_ptr<BlockingCounter> CreateSparseSlices(
811 const ConstMatrixMapL& mat, bool transpose, int slice_num_rows,
812 int slice_block_size, int slice_num_cols,
813 std::vector<std::vector<SparseSlice<TL>*>>* mat_slices,
814 const DeviceBase::CpuWorkerThreads* thread_pool);
815
816 // This function chops "mat" along column dimension into pieces with at most N
817 // columns, and concatenates the pieces one after the other in "buffer". It
818   // returns the list of the pieces in "slices", along with a BlockingCounter
819 // which should be used to wait for the shuffle operations to complete.
820 static inline std::unique_ptr<BlockingCounter> CreateDenseSlices(
821 const ConstMatrixMapR& mat, int row_start, int num_rows, int col_start,
822 int num_cols, const DeviceBase::CpuWorkerThreads* thread_pool,
823 MatrixR* buffer, std::vector<ConstMatrixMapR*>* slices);
824
825 // Helper function for CreateDenseSlices to move the data around. It returns a
826 // BlockingCounter which should be used to wait for the shuffle operations to
827 // complete.
828 static inline BlockingCounter* ShuffleMatrix(
829 const ConstMatrixMapR& mat, int slice_row_start, int slice_num_rows,
830 int slice_col_start, int slice_num_cols, const int N,
831 const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer);
832
833 // Helper function for CreateDenseSlices to create slices.
834 static inline void SliceMatrix(const MatrixR& mat, const int num_rows,
835 const int num_slices,
836 std::vector<ConstMatrixMapR*>* slices);
837
838 // Heuristics to compute various block sizes.
839 // KR, NR: block sizes for "right". We run blocking iterations that operate on
840 // matrices with at most this size.
841 // KL: grid size along the column dimension used while encoding left.
842 // IB, JB: number of left and right slices to multiply together. This is used
843   // for ordering different ComputeOutputBlock operations inside each blocking
844 // iteration so as to potentially reduce the working set size.
845 static inline void ComputeBlockSizes(const ConstMatrixMapL& left,
846 const ConstMatrixMapR& right,
847 bool transpose_left, int num_threads,
848 int* KR, int* NR, int* KL, int* JB,
849 int* IB);
850
851 TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMul);
852 };
853
854 #ifdef TENSORFLOW_USE_LIBXSMM
855 template <typename TL, typename TR>
856 class LibxsmmSparseMatMul {
857 using MatrixL = BasicMatrix<TL>;
858 using MatrixR = BasicMatrix<TR>;
859 using ConstMatrixMapL = BasicMatrixMap<const TL>;
860 using ConstMatrixMapR = BasicMatrixMap<const TR>;
861 using MatrixMapR = BasicMatrixMap<TR>;
862
863 public:
864 // This structure contains a set of libxsmm kernels for sizes that have been
865 // encountered previously by this operator so that libxsmm does not need to
866 // reallocate its scratchpad memory each time (which hurts performance
867 // substantially).
868 struct TensorInfoCache {
869 struct TensorInfoCacheEntry {
870 // Parameters for kernel
871 int M;
872 int K;
873 int N;
874 int max_threads;
875 // libxsmm handle and matrix data
876 libxsmm_spmdm_handle handle;
877 libxsmm_CSR_sparseslice* output_csr;
878 // Chain to non-libxsmm implementation's cache in case that ever becomes
879 // useful (it is an empty struct right now)
880 typename SparseMatMul<TL, TR>::TensorInfoCache
881 non_libxsmm_cache; // Currently not used
882 };
883 // protects entries; invariant: entries is a valid std::multimap
884 tensorflow::mutex lock;
885 // Because there could be multiple matrix multiplies with the same sizes
886 // going on at the same time, we need to allow multiple cache entries for a
887 // given set of parameters. Taking and returning entries is used to make
888 // sure the same cache entry is not used from two threads at a time.
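    // Typical usage is take_cache_entry() before a multiply and
    // return_cache_entry() once that multiply has finished.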
889 std::multimap<std::tuple<int, int, int, int>,
890 std::unique_ptr<TensorInfoCacheEntry>>
891 entries TF_GUARDED_BY(lock);
892
893     TensorInfoCache() : lock(), entries() {}
894 // Look up and remove first entry with these parameters, creating one if
895 // there isn't one
896     std::unique_ptr<TensorInfoCacheEntry> take_cache_entry(int M, int K, int N,
897 int max_threads)
898 TF_LOCKS_EXCLUDED(lock) {
899 tensorflow::mutex_lock ml(lock);
900 auto key = std::make_tuple(M, K, N, max_threads);
901 auto it = entries.find(key);
902 if (it != entries.end()) {
903 auto val = std::move(it->second);
904 entries.erase(it);
905 return val;
906 } else {
907 std::unique_ptr<TensorInfoCacheEntry> e{
908 new TensorInfoCacheEntry{M, K, N, max_threads, {}, nullptr}};
909 // setup scoped allocator, which uses cpu_allocator() for this scope
910 const libxsmm_tf_allocator<libxsmm_scratch_allocator> tf_allocator;
911 libxsmm_spmdm_init(M, N, K, max_threads, &e->handle, &e->output_csr);
912 return e;
913 }
914 }
915 // Add a cache entry with certain parameters
916     void return_cache_entry(std::unique_ptr<TensorInfoCacheEntry> e)
917 TF_LOCKS_EXCLUDED(lock) {
918 tensorflow::mutex_lock ml(lock);
919 auto key = std::make_tuple(e->M, e->K, e->N, e->max_threads);
920 entries.insert(std::make_pair(key, std::move(e)));
921 }
922     ~TensorInfoCache() {
923 tensorflow::mutex_lock ml(lock);
924 for (auto& p : entries) {
925 libxsmm_spmdm_destroy(&p.second->handle);
926 }
927 entries.clear();
928 }
929
930 private:
931 TF_DISALLOW_COPY_AND_ASSIGN(TensorInfoCache);
932 };
933
934 // Perform matrix multiplication of "left" and "right", and store the result
935 // in *"output".
936 public:
937 static inline void Compute(TensorInfoCache* cache,
938 const ConstMatrixMapL& left,
939 const ConstMatrixMapR& right, bool transpose_left,
940 const DeviceBase::CpuWorkerThreads* thread_pool,
941 bool transpose_output, MatrixMap* output);
942
943 private:
944 TF_DISALLOW_COPY_AND_ASSIGN(LibxsmmSparseMatMul);
945 };
946 #endif
947
948 template <typename TL, typename TR,
949 template <typename TL2, typename TR2> class DoMatMul>
950 class SparseMatMulOp : public OpKernel {
951 using MatrixR = BasicMatrix<TR>;
952 using ConstMatrixMapR = BasicMatrixMap<const TR>;
953
954 public:
955   explicit SparseMatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
956 OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
957 OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
958 OP_REQUIRES_OK(ctx, ctx->GetAttr("a_is_sparse", &a_is_sparse_));
959 OP_REQUIRES_OK(ctx, ctx->GetAttr("b_is_sparse", &b_is_sparse_));
960 }
961
962   void Compute(OpKernelContext* ctx) override {
963 const Tensor& a = ctx->input(0);
964 const Tensor& b = ctx->input(1);
965 OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a.shape()),
966 errors::InvalidArgument("a is not a matrix"));
967 OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()),
968 errors::InvalidArgument("b is not a matrix"));
969
970 const int m = transpose_a_ ? a.dim_size(1) : a.dim_size(0);
971 const int k = transpose_a_ ? a.dim_size(0) : a.dim_size(1);
972 const int n = transpose_b_ ? b.dim_size(0) : b.dim_size(1);
973 const int k2 = transpose_b_ ? b.dim_size(1) : b.dim_size(0);
974
975 OP_REQUIRES(ctx, k == k2,
976 errors::InvalidArgument(
977 "Matrix size incompatible: a: ", a.shape().DebugString(),
978 ", b: ", b.shape().DebugString()));
979 OP_REQUIRES(ctx, m >= 0 && n >= 0 && k >= 0,
980 errors::InvalidArgument(
981 "Matrix dimensions cannot be negative: a: ",
982 a.shape().DebugString(), ", b: ", b.shape().DebugString()));
983 Tensor* output = nullptr;
984 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({m, n}), &output));
985
986 // Return early if at least one of the output dimension size is 0.
987 if (m == 0 || n == 0) {
988 return;
989 }
990
991 if (k == 0) {
992 // If the inner dimension k in the matrix multiplication is zero, we fill
993 // the output with zeros.
994 functor::SetZeroFunctor<CPUDevice, float> f;
995 f(ctx->eigen_device<CPUDevice>(), output->flat<float>());
996 return;
997 }
998
999 auto out = output->matrix<float>();
1000
1001 std::unique_ptr<Tensor> a_float;
1002 std::unique_ptr<Tensor> b_float;
1003 if (!a_is_sparse_ && !b_is_sparse_) {
1004 auto left = &a;
1005 auto right = &b;
1006 // TODO(agarwal): multi-thread the conversions from bfloat16 to float.
1007 if (std::is_same<TL, bfloat16>::value) {
1008 a_float.reset(new Tensor(DT_FLOAT, a.shape()));
1009 BFloat16ToFloat(a.flat<bfloat16>().data(),
1010 a_float->flat<float>().data(), a.NumElements());
1011 left = a_float.get();
1012 }
1013 if (std::is_same<TR, bfloat16>::value) {
1014 b_float.reset(new Tensor(DT_FLOAT, b.shape()));
1015 BFloat16ToFloat(b.flat<bfloat16>().data(),
1016 b_float->flat<float>().data(), b.NumElements());
1017 right = b_float.get();
1018 }
1019 Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
1020 dim_pair[0].first = transpose_a_ ? 0 : 1;
1021 dim_pair[0].second = transpose_b_ ? 1 : 0;
1022
1023 out.device(ctx->template eigen_device<CPUDevice>()) =
1024 left->matrix<float>().contract(right->matrix<float>(), dim_pair);
1025 return;
1026 }
1027
1028 auto left = &a;
1029 auto right = &b;
1030 bool transpose_output = false;
1031 bool transpose_a = transpose_a_;
1032 bool transpose_b = transpose_b_;
1033 if (!a_is_sparse_) {
1034 // Swap the order of multiplications using the identity:
1035 // A * B = (B' * A')'.
1036 std::swap(left, right);
1037 std::swap(transpose_a, transpose_b);
1038 transpose_a = !transpose_a;
1039 transpose_b = !transpose_b;
1040 transpose_output = !transpose_output;
1041 }
1042
1043 std::unique_ptr<Tensor> right_tr;
1044 if (transpose_b) {
1045 // TODO(agarwal): avoid transposing the matrix here and directly handle
1046 // transpose in CreateDenseSlices.
1047 OP_REQUIRES(ctx, right->dim_size(0) != 0,
1048                 errors::InvalidArgument("b has an entry 0 in its shape."));
1049 OP_REQUIRES(ctx, right->dim_size(1) != 0,
1050                 errors::InvalidArgument("b has an entry 0 in its shape."));
1051 right_tr.reset(
1052 new Tensor(right->dtype(),
1053 TensorShape({right->dim_size(1), right->dim_size(0)})));
1054
1055 const auto perm = dsizes_10();
1056 if (transpose_output) {
1057 right_tr->matrix<TL>().device(ctx->template eigen_device<CPUDevice>()) =
1058 right->matrix<TL>().shuffle(perm);
1059 } else {
1060 right_tr->matrix<TR>().device(ctx->template eigen_device<CPUDevice>()) =
1061 right->matrix<TR>().shuffle(perm);
1062 }
1063 right = right_tr.get();
1064 }
1065
1066 if (transpose_output) {
1067 DoMatMul<TR, TL>::Compute(&this->cache_tr_, left->matrix<TR>(),
1068 right->matrix<TL>(), transpose_a,
1069 ctx->device()->tensorflow_cpu_worker_threads(),
1070 transpose_output, &out);
1071 } else {
1072 DoMatMul<TL, TR>::Compute(&this->cache_nt_, left->matrix<TL>(),
1073 right->matrix<TR>(), transpose_a,
1074 ctx->device()->tensorflow_cpu_worker_threads(),
1075 transpose_output, &out);
1076 }
1077 }
1078
1079 private:
1080 bool transpose_a_;
1081 bool transpose_b_;
1082 bool a_is_sparse_;
1083 bool b_is_sparse_;
1084
1085 // Cache for non-transposed-output multiply
1086 typename DoMatMul<TL, TR>::TensorInfoCache cache_nt_;
1087 // Cache for transposed-output multiply
1088 typename DoMatMul<TR, TL>::TensorInfoCache cache_tr_;
1089
1090 TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMulOp);
1091 };
1092
1093 template <typename TL, typename TR>
1094 inline void SparseMatMul<TL, TR>::ComputeOutputBlock(
1095 const std::vector<SparseSlice<TL>*>& left,
1096 const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right, int num_cols,
1097 int output_row_offset, int output_col_offset, bool assign,
1098 bool transpose_output, MatrixMap* output) {
1099 const auto perm = dsizes_10();
1100 int num_rows = left[0]->num_rows;
1101 const int rhs_num_cols = right.dimension(1);
1102 DCHECK_LE(num_cols, rhs_num_cols);
1103 Matrix out(num_rows, rhs_num_cols);
1104 out.setZero();
1105 if (num_cols == N) {
1106 GEPP<TL, TR, N>(left, right, num_cols, &out);
1107 } else {
1108 GEPP<TL, TR, -1>(left, right, num_cols, &out);
1109 }
1110 if (!assign) {
1111 const DSizes begin(output_row_offset, output_col_offset);
1112 const DSizes sizes(num_rows, num_cols);
1113 if (transpose_output) {
1114 if (num_cols == rhs_num_cols) {
1115 output->shuffle(perm).slice(begin, sizes) += out;
1116 } else {
1117 const auto zero = dsizes_00();
1118 output->shuffle(perm).slice(begin, sizes) += out.slice(zero, sizes);
1119 }
1120 } else {
1121 if (num_cols == rhs_num_cols) {
1122 output->slice(begin, sizes) += out;
1123 } else {
1124 const auto zero = dsizes_00();
1125 output->slice(begin, sizes) += out.slice(zero, sizes);
1126 }
1127 }
1128 } else {
1129 std::unique_ptr<Matrix> out_tr;
1130 if (transpose_output) {
1131 out_tr.reset(new Matrix(rhs_num_cols, num_rows));
1132 *out_tr = out.shuffle(perm);
1133 std::swap(output_row_offset, output_col_offset);
1134 std::swap(num_rows, num_cols);
1135 }
1136 const Matrix& final_out = transpose_output ? *out_tr : out;
1137 for (int i = 0; i < num_rows; ++i) {
1138 memcpy(&(*output)(output_row_offset + i, output_col_offset),
1139 &final_out(i, 0), num_cols * sizeof(float));
1140 }
1141 }
1142 }
1143
1144 template <typename TL, typename TR>
1145 inline std::unique_ptr<BlockingCounter>
1146 SparseMatMul<TL, TR>::CreateSparseSlices(
1147 const typename SparseMatMul<TL, TR>::ConstMatrixMapL& mat, bool transpose,
1148 int slice_num_rows, int slice_block_size, int slice_num_cols,
1149 std::vector<std::vector<SparseSlice<TL>*>>* mat_slices,
1150 const DeviceBase::CpuWorkerThreads* thread_pool) {
1151 const int mat_num_rows = transpose ? mat.dimension(1) : mat.dimension(0);
1152 const int mat_num_cols = transpose ? mat.dimension(0) : mat.dimension(1);
1153 const int num_slices_dim0 =
1154 std::max(1, (mat_num_rows + slice_num_rows - 1) / slice_num_rows);
1155 const int num_slices_dim1 =
1156 std::max(1, (mat_num_cols + slice_num_cols - 1) / slice_num_cols);
1157 mat_slices->resize(num_slices_dim0);
1158 BlockingCounter* counter =
1159 new BlockingCounter(num_slices_dim0 * num_slices_dim1);
1160 auto work = [counter, transpose](SparseSlice<TL>* sparse_slice,
1161 SparseMatMul<TL, TR>::ConstMatrixMapL* slice,
1162 int col_offset) {
1163 if (transpose) {
1164 sparse_slice->template Initialize<true>(*slice, col_offset);
1165 } else {
1166 sparse_slice->template Initialize<false>(*slice, col_offset);
1167 }
1168 delete slice;
1169 counter->DecrementCount();
1170 };
1171 for (int i = 0; i < num_slices_dim0; ++i) {
1172 (*mat_slices)[i].resize(num_slices_dim1);
1173 int num_rows =
1174 std::min<int>(slice_num_rows, mat_num_rows - i * slice_num_rows);
1175 for (int j = 0; j < num_slices_dim1; ++j) {
1176 int num_cols =
1177 std::min<int>(slice_num_cols, mat_num_cols - j * slice_num_cols);
1178 SparseMatMul<TL, TR>::ConstMatrixMapL* slice = nullptr;
1179 if (transpose) {
1180 slice = new SparseMatMul<TL, TR>::ConstMatrixMapL(
1181 &mat(0, i * slice_num_rows), mat.dimensions());
1182 } else {
1183 DSizes d(num_rows, mat_num_cols);
1184 slice = new SparseMatMul<TL, TR>::ConstMatrixMapL(
1185 &mat(i * slice_num_rows, 0), d);
1186 }
1187 auto* sparse_slice =
1188 new SparseSlice<TL>(num_rows, num_cols, slice_block_size);
1189 (*mat_slices)[i][j] = sparse_slice;
1190 thread_pool->workers->Schedule(
1191 [=]() { work(sparse_slice, slice, slice_num_cols * j); });
1192 }
1193 }
1194 return std::unique_ptr<BlockingCounter>(counter);
1195 }
1196 #define LOAD(x) Eigen::internal::ploadu<Packet>((x));
1197 #define INTERLEAVE(x) Eigen::internal::pinterleave4x64<Packet>(x);
1198 #define STORE(x, y) Eigen::internal::pstoreu<float>(x, y);
1199
1200 template <int NUM_ELEM = -1>
1201 ALWAYS_INLINE void CopyAndMayBeInterleaveBfloat16(void* bdst, const void* bsrc,
1202 int num_elements) {
1203 DCHECK_GE(kNumOperands, 8);
1204 static const int kStep = kNumOperands * sizeof(float) / sizeof(bfloat16);
1205 const int num = (NUM_ELEM == -1) ? num_elements : NUM_ELEM;
1206 DCHECK_EQ(num, num_elements);
1207 const float* src = reinterpret_cast<const float*>(bsrc);
1208 float* dst = reinterpret_cast<float*>(bdst);
1209 for (int index = 0; index + kStep <= num; index += kStep) {
1210 auto in = LOAD(src);
1211 auto tmp = INTERLEAVE(in);
1212 STORE(dst, tmp);
1213 src += kNumOperands;
1214 dst += kNumOperands;
1215 }
1216 if (num % kStep != 0) {
1217 memcpy(dst, src, (num % kStep) * sizeof(bfloat16));
1218 }
1219 }
1220
1221 template <typename T>
1222 ALWAYS_INLINE void CopyAndMayBeInterleave(void* dst, const void* src,
1223 int num_elements) {
1224 if (std::is_same<T, float>::value || kNumOperands < 8) {
1225 memcpy(dst, src, num_elements * sizeof(T));
1226 } else if (std::is_same<T, bfloat16>::value) {
1227 if (num_elements == N) {
1228 CopyAndMayBeInterleaveBfloat16<N>(dst, src, num_elements);
1229 } else {
1230 CopyAndMayBeInterleaveBfloat16<-1>(dst, src, num_elements);
1231 }
1232 } else {
1233 LOG(FATAL) << "Unsupported type";
1234 }
1235 }
1236
1237 #undef LOAD
1238 #undef INTERLEAVE
1239 #undef STORE
1240
1241 template <typename TL, typename TR>
1242 inline BlockingCounter* SparseMatMul<TL, TR>::ShuffleMatrix(
1243 const typename SparseMatMul<TL, TR>::ConstMatrixMapR& mat,
1244 int slice_row_start, int slice_num_rows, int slice_col_start,
1245 int slice_num_cols, const int N,
1246 const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer) {
1247 DCHECK_EQ(N % 2, 0);
1248 DCHECK_LE(kNumOperands * sizeof(float) / sizeof(TR), N);
1249 // Note(nikhilsarda): This heuristic is optimal in benchmarks as of
1250 // Jan 21, 2020.
1251 int num_threads = std::min(thread_pool->num_threads, 8);
1252 BlockingCounter* counter = new BlockingCounter(num_threads);
1253 DCHECK_EQ(N, buffer->dimension(1));
1254 auto shuffle_work = [&mat, slice_row_start, slice_num_rows, slice_col_start,
1255 slice_num_cols, N, buffer, counter](int s, int e) {
1256 const int row_start = s % slice_num_rows + slice_row_start;
1257 const int col_start = s / slice_num_rows * N + slice_col_start;
1258 auto* out_start = &(*buffer)(s, 0);
1259 const auto* input_start = &mat(row_start, col_start);
1260 const auto* input_end = &mat(slice_row_start + slice_num_rows - 1,
1261 slice_col_start + slice_num_cols - 1);
1262 const int mat_num_cols = mat.dimension(1);
1263 const int row_slice_size = slice_num_rows * mat_num_cols;
1264
1265 const int aligned_end = slice_num_cols / N * slice_num_rows;
1266 const int e1 = std::min(e, aligned_end);
1267 while (s < e1) {
1268 CopyAndMayBeInterleave<TR>(out_start, input_start, N);
1269 out_start += N;
1270 input_start += mat_num_cols;
1271 if (input_start > input_end) {
1272 input_start = input_start - row_slice_size + N;
1273 }
1274 ++s;
1275 }
1276 int s1 = std::max(s, aligned_end);
1277 const int copy_num_cols = slice_num_cols % N;
1278 while (s1 < e) {
1279 CopyAndMayBeInterleave<TR>(out_start, input_start, copy_num_cols);
1280 out_start += N;
1281 input_start += mat_num_cols;
1282 ++s1;
1283 }
1284 if (counter) counter->DecrementCount();
1285 };
1286
1287 int start = 0;
1288 int end = 0;
1289 int num_out_rows = (slice_num_cols + N - 1) / N * slice_num_rows;
1290 DCHECK_LE(num_out_rows, buffer->dimension(0));
1291 for (int i = std::max(1, num_threads); i > 0; --i) {
1292 end = start + num_out_rows / i;
1293 thread_pool->workers->Schedule([=]() { shuffle_work(start, end); });
1294 num_out_rows -= (end - start);
1295 start = end;
1296 }
1297 return counter;
1298 }
1299
1300 template <typename TL, typename TR>
1301 inline void SparseMatMul<TL, TR>::SliceMatrix(
1302 const MatrixR& mat, const int num_rows, const int num_slices,
1303 std::vector<typename SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) {
1304 slices->resize(num_slices);
1305 DSizes d(num_rows, mat.dimension(1));
1306 DCHECK_LE(num_rows * num_slices, mat.dimension(0));
1307 for (int i = 0; i < num_slices; ++i) {
1308 (*slices)[i] = new ConstMatrixMapR(&mat(i * num_rows, 0), d);
1309 }
1310 }
1311
1312 template <typename TL, typename TR>
1313 inline std::unique_ptr<BlockingCounter> SparseMatMul<TL, TR>::CreateDenseSlices(
1314 const typename SparseMatMul<TL, TR>::ConstMatrixMapR& mat, int row_start,
1315 int num_rows, int col_start, int num_cols,
1316 const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer,
1317 std::vector<typename SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) {
1318 std::unique_ptr<BlockingCounter> shuffle_counter(ShuffleMatrix(
1319 mat, row_start, num_rows, col_start, num_cols, N, thread_pool, buffer));
1320 const int num_slices = (num_cols + N - 1) / N;
1321 SliceMatrix(*buffer, num_rows, num_slices, slices);
1322 return shuffle_counter;
1323 }
1324
1325 template <typename TL, typename TR>
1326 inline void SparseMatMul<TL, TR>::ComputeBlockSizes(
1327 const typename SparseMatMul<TL, TR>::ConstMatrixMapL& left,
1328 const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right,
1329 bool transpose_left, int num_threads, int* KR, int* NR, int* KL, int* JB,
1330 int* IB) {
1331 // Heuristics for calculating block sizes
1332 // Assume two hyperthreads per core.
1333 const int est_num_cores = std::max(1, (num_threads + 1) / 2);
1334 // Use block of rhs with at most 128K floats per core.
1335 const int mem = est_num_cores * 128 * 1024;
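  // For example, 16 threads gives est_num_cores = 8 and mem = 1,048,576
  // floats, so *KR below is initially capped at mem / 256 = 4096 rows (before
  // being rounded to a multiple of K).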
1336 *KR = std::min(static_cast<int>(right.dimension(0)), mem / 256);
1337 *NR = right.dimension(1);
1338 if (*KR * *NR > mem) {
1339 // 4096 may be enough to amortize the cost of writes.
1340 *KR = std::min<int>(*KR, 4096);
1341 }
1342 // Use sizes that are multiples of K and 256.
1343 *KR = std::max(1, *KR / K) * K;
1344 *NR = std::max(1, *NR / 256) * 256;
1345 if (*KR * *NR > mem) {
1346 *NR = mem / *KR;
1347 }
1348 *NR = std::max(1, *NR / 256) * 256;
1349
1350 const int left_dim0 = transpose_left ? left.dimension(1) : left.dimension(0);
1351 const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1);
1352 for (*KL = 1024; *KL > K; *KL /= 2) {
1353 if (*KR % *KL == 0 &&
1354 std::max<int>(1, left_dim0 / 64) * (left_dim1 / *KL) > est_num_cores) {
1355 break;
1356 }
1357 }
1358 DCHECK_EQ(*KL % K, 0);
1359 DCHECK_GE(*KR, *KL);
1360 if (*KR < right.dimension(0)) {
1361 CHECK_EQ(*KR % *KL, 0);
1362 }
1363
1364 *JB = std::max(1, static_cast<int>(sqrt(num_threads) / 2.0));
1365 *IB = 8 * *JB;
1366 DCHECK_EQ(N * sizeof(float) % 64, size_t{0});
1367 }
1368
1369 #ifdef TENSORFLOW_USE_LIBXSMM
1370
1371 template <typename F>
1372 void do_on_all_threads(const DeviceBase::CpuWorkerThreads* thread_pool,
1373 const F& f) {
1374 int num_threads = thread_pool->num_threads;
1375 if (num_threads == 0) {
1376 LOG(FATAL) << "Have 0 threads in thread pool";
1377 } else if (num_threads == 1) {
1378 f(0);
1379 } else {
1380 BlockingCounter counter(num_threads - 1);
1381 for (int i = 1; i < num_threads; ++i) {
1382 thread_pool->workers->Schedule([&, i]() {
1383 f(i);
1384 counter.DecrementCount();
1385 });
1386 }
1387 f(0);
1388 counter.Wait();
1389 }
1390 }
1391
1392 template <typename T>
1393 struct empty_type_wrapper {};
1394
1395 // Copies of interface to libxsmm_spmdm_createSparseSlice_*_notrans_thread to
1396 // allow overloading
1397 void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
1398 empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA,
1399 const float* A, libxsmm_CSR_sparseslice* libxsmm_output_csr_a, int block_id,
1400 int tid, int nthreads) {
1401 return libxsmm_spmdm_createSparseSlice_fp32_thread(
1402 handle, transA, A, libxsmm_output_csr_a, block_id, tid, nthreads);
1403 }
1404 void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
1405 empty_type_wrapper<bfloat16>, const libxsmm_spmdm_handle* handle,
1406 char transA, const bfloat16* A,
1407 libxsmm_CSR_sparseslice* libxsmm_output_csr_a, int block_id, int tid,
1408 int nthreads) {
1409 return libxsmm_spmdm_createSparseSlice_bfloat16_thread(
1410 handle, transA, reinterpret_cast<const libxsmm_bfloat16*>(A),
1411 libxsmm_output_csr_a, block_id, tid, nthreads);
1412 }
1413
1414 void wrapper_libxsmm_spmdm_compute_generic_thread(
1415 empty_type_wrapper<bfloat16>, const libxsmm_spmdm_handle* handle,
1416 char transA, char transB, const bfloat16* alpha,
1417 libxsmm_CSR_sparseslice* A_sparse, const bfloat16* B, char transC,
1418 const bfloat16* beta, float* C, int block_id, int tid, int nthreads) {
1419 return libxsmm_spmdm_compute_bfloat16_thread(
1420 handle, transA, transB, reinterpret_cast<const libxsmm_bfloat16*>(alpha),
1421 A_sparse, reinterpret_cast<const libxsmm_bfloat16*>(B), transC,
1422 reinterpret_cast<const libxsmm_bfloat16*>(beta), C, block_id, tid,
1423 nthreads);
1424 }
wrapper_libxsmm_spmdm_compute_generic_thread(empty_type_wrapper<float>,const libxsmm_spmdm_handle * handle,char transA,char transB,const float * alpha,libxsmm_CSR_sparseslice * A_sparse,const float * B,char transC,const float * beta,float * C,int block_id,int tid,int nthreads)1425 void wrapper_libxsmm_spmdm_compute_generic_thread(
1426 empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA,
1427 char transB, const float* alpha, libxsmm_CSR_sparseslice* A_sparse,
1428 const float* B, char transC, const float* beta, float* C, int block_id,
1429 int tid, int nthreads) {
1430 return libxsmm_spmdm_compute_fp32_thread(handle, transA, transB, alpha,
1431 A_sparse, B, transC, beta, C,
1432 block_id, tid, nthreads);
1433 }
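// The empty_type_wrapper argument is a compile-time tag: a call such as
//   wrapper_libxsmm_spmdm_compute_generic_thread(empty_type_wrapper<TL>{}, ...)
// resolves to the fp32 or bfloat16 libxsmm entry point depending on TL, with
// no runtime branching.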

template <typename TL, typename TR>
inline void LibxsmmSparseMatMul<TL, TR>::Compute(
    typename LibxsmmSparseMatMul<TL, TR>::TensorInfoCache* cache,
    const typename LibxsmmSparseMatMul<TL, TR>::ConstMatrixMapL& left,
    const typename LibxsmmSparseMatMul<TL, TR>::ConstMatrixMapR& right,
    bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
    bool transpose_output, MatrixMap* output) {
  const int num_threads = thread_pool->num_threads;
  const int left_dim0 = transpose_left ? left.dimension(1) : left.dimension(0);
  const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1);
  const int right_dim0 = right.dimension(0);
  const int right_dim1 = right.dimension(1);
  CHECK_EQ(left_dim1, right_dim0);
  CHECK_EQ(left_dim0,
           (transpose_output ? output->dimension(1) : output->dimension(0)));
  CHECK_EQ(right_dim1,
           (transpose_output ? output->dimension(0) : output->dimension(1)));
#if 0  // this issue seems to be resolved
  if (left_dim0 < 32 || left_dim1 < 32 || right_dim1 < 32) {
    // Causes problems in libxsmm
    SparseMatMul<TL, TR>::Compute(
        nullptr /* Assumes no cached data for fallback */, left, right,
        transpose_left, thread_pool, transpose_output, output);
    return;
  }
#endif
  auto left_data = left.data();
  auto right_data = right.data();
  auto output_data = output->data();
  // Initialize libxsmm for this matrix; make sure another thread doesn't use
  // this handle.
  auto entry =
      cache->take_cache_entry(left_dim0, right_dim0, right_dim1, num_threads);
  // Convert the left matrix to compressed sparse row (CSR) format.
  ptrdiff_t total_num_creation_blocks =
      libxsmm_spmdm_get_num_createSparseSlice_blocks(&entry->handle);
  std::atomic<int> cur_create_block_number;
  cur_create_block_number.store(0);
  do_on_all_threads(thread_pool, [&](int i) {
    while (true) {
      int work_item = cur_create_block_number.fetch_add(1);
      if (work_item >= total_num_creation_blocks) break;
      wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
          empty_type_wrapper<TL>{}, &entry->handle,
          (transpose_left ? 'T' : 'N'), left_data, entry->output_csr, work_item,
          i, num_threads);
    }
  });
  // Do matrix-matrix multiplication.
  ptrdiff_t total_num_mult_blocks =
      libxsmm_spmdm_get_num_compute_blocks(&entry->handle);
  std::atomic<int> cur_mult_block_number;
  cur_mult_block_number.store(0);
  do_on_all_threads(thread_pool, [&](int i) {
    while (true) {
      int work_item = cur_mult_block_number.fetch_add(1);
      if (work_item >= total_num_mult_blocks) break;
      const TL alpha(1.0);  // Stored in a variable so we can get a pointer.
      const TL beta(0.0);   // Stored in a variable so we can get a pointer.
      wrapper_libxsmm_spmdm_compute_generic_thread(
          empty_type_wrapper<TL>{}, &entry->handle,
          (transpose_left ? 'T' : 'N'), 'N', &alpha, entry->output_csr,
          right_data, (transpose_output ? 'T' : 'N'), &beta, output_data,
          work_item, i, num_threads);
    }
  });
  // Put the handle + CSR storage back into the cache.
  cache->return_cache_entry(std::move(entry));
}

#endif  // TENSORFLOW_USE_LIBXSMM

// Here is an overview of the SparseMatMul code. Note that we assume that the
// left matrix is sparse.
//
// The matrix "left" is divided into a grid with block size (M, KL). Each
// block is encoded as a SparseSlice. These grid elements are stored as
// std::vector<std::vector<SparseSlice>>. Each element of the outer vector
// represents M rows of the left matrix. Let's call these elements l_i and
// let's call each element of the inner vector L_mk.
//
// The matrix "right" is divided into a grid with block size (KR, NR). Let's
// denote the blocks on the right as R_kn. Note that we ensure that KL divides
// KR so that for each element R_kn, we don't need to multiply it with any
// partial L_mk blocks.
//
// We then multiply each right side block R_kn with the full "left" matrix and
// update the output. These iterations are run sequentially since the R_kn are
// packed into the same underlying temporary buffer.
//
// In each iteration we do the following:
// 1. Create slices r_j of R_kn: We split R_kn into vertical blocks with N
//    (=128) columns and then concatenate these slices into a buffer. This is
//    done so that each slice r_j of R_kn is stored contiguously in memory.
//    Note that if R_kn has dimensions (KR, NR), we create NR / N slices, and
//    the buffer has dimensions (KR * NR / N, N) (assuming N divides NR).
// 2. For each (l_i, r_j), we compute the inner product using the GEPP function
//    and update the output block o_ij. These calls are further blocked to
//    reduce the working set size. In each iteration we take IB elements from
//    {l_i} and JB elements from {r_j} and compute the IB * JB inner products.
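//
// Illustrative example (hypothetical block sizes, using the constant N = 128
// defined above): with KR = 512 and NR = 1024, a single right-hand block R_kn
// of shape (512, 1024) is repacked into 1024 / 128 = 8 contiguous slices,
// giving a buffer of shape (512 * 8, 128) = (4096, 128). Each of those 8
// slices is then paired with every left slice l_i, in groups of IB left slices
// by JB right slices per batch of tasks.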
template <typename TL, typename TR>
inline void SparseMatMul<TL, TR>::Compute(
    typename SparseMatMul<TL, TR>::TensorInfoCache* /*cache*/,
    const typename SparseMatMul<TL, TR>::ConstMatrixMapL& left,
    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right,
    bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
    bool transpose_output, MatrixMap* output) {
  const int num_threads = thread_pool->num_threads;
  int KR, NR, KL, JB, IB;
  ComputeBlockSizes(left, right, transpose_left, num_threads, &KR, &NR, &KL,
                    &JB, &IB);
  // Slice the left matrix.
  std::vector<std::vector<SparseSlice<TL>*>> left_slices;
  std::unique_ptr<BlockingCounter> sparse_slice_counter =
      CreateSparseSlices(ConstMatrixMapL(left.data(), left.dimensions()),
                         transpose_left, M, K, KL, &left_slices, thread_pool);
  const int num_left_slices = left_slices.size();

  const int right_dim0 = right.dimension(0);
  const int right_dim1 = right.dimension(1);
  // Allocate a buffer for storing slices of the right matrix. Note that the
  // buffer needs enough space to hold at most a KR * NR matrix, since that is
  // the block size per iteration.
  const int buffer_num_rows =
      std::min(KR, right_dim0) * ((std::min(NR, right_dim1) + N - 1) / N);
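  // For example (illustrative), if right_dim1 = 300 and N = 128, only
  // (300 + 127) / 128 = 3 column slices are needed per KR-row block, so the
  // buffer holds std::min(KR, right_dim0) * 3 rows of width N.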
  MatrixR buffer(buffer_num_rows, N);
  std::vector<ConstMatrixMapR*> right_slices;

  std::vector<SparseSlice<TL>*> block_left_slices;
  std::vector<std::function<void(void)>> tasks;
  // Number of blocks based on block sizes of KR * NR.
  const int num_k_blocks = (right_dim0 + KR - 1) / KR;
  const int num_n_blocks = (right_dim1 + NR - 1) / NR;
  std::unique_ptr<BlockingCounter> dense_slice_counter;

  for (int nb = 0; nb < num_n_blocks; ++nb) {
    const int right_num_cols =
        std::min(NR, static_cast<int>(right_dim1 - NR * nb));
    for (int kb = 0; kb < num_k_blocks; ++kb) {
      const int right_num_rows =
          std::min(KR, static_cast<int>(right_dim0 - KR * kb));
      dense_slice_counter = CreateDenseSlices(
          right, kb * KR, right_num_rows, nb * NR, right_num_cols, thread_pool,
          &buffer, &right_slices);
      const int num_right_slices = right_slices.size();
      tasks.reserve(num_left_slices * num_right_slices);
      for (int j_outer = 0; j_outer < num_right_slices; j_outer += JB) {
        for (int i_outer = 0; i_outer < num_left_slices; i_outer += IB) {
          for (int j_inner = j_outer;
               j_inner < std::min(num_right_slices, j_outer + JB); ++j_inner) {
            const int num_cols = std::min(N, right_num_cols - N * j_inner);
            for (int i_inner = i_outer;
                 i_inner < std::min(num_left_slices, i_outer + IB); ++i_inner) {
              block_left_slices.clear();
              int begin = kb * KR / KL;
              int end = std::min<int>((kb + 1) * KR / KL,
                                      (right.dimension(0) + KL - 1) / KL);
              DCHECK_LT(begin, end);
              block_left_slices.insert(block_left_slices.begin(),
                                       left_slices[i_inner].begin() + begin,
                                       left_slices[i_inner].begin() + end);
              tasks.push_back(std::bind(
                  &ComputeOutputBlock, block_left_slices,
                  std::ref(*right_slices[j_inner]), num_cols, M * i_inner,
                  N * j_inner + nb * NR, kb == 0, transpose_output, output));
            }
          }
        }
      }
      if (sparse_slice_counter) {
        sparse_slice_counter->Wait();
        sparse_slice_counter.reset(nullptr);
      }
      if (dense_slice_counter) {
        dense_slice_counter->Wait();
        dense_slice_counter.reset(nullptr);
      }
      BlockingCounter bc(tasks.size());
      for (const auto& t : tasks) {
        thread_pool->workers->Schedule([&bc, &t]() {
          t();
          bc.DecrementCount();
        });
      }
      bc.Wait();
      tasks.clear();
      for (auto& temp : right_slices) {
        delete temp;
      }
      right_slices.clear();
    }
  }
  for (auto& left_slice : left_slices) {
    for (auto& temp : left_slice) {
      delete temp;
    }
    left_slice.clear();
  }
}

#define REGISTER_SPARSE_MATMUL(TA, TB)                   \
  REGISTER_KERNEL_BUILDER(Name("SparseMatMul")           \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<TA>("Ta")  \
                              .TypeConstraint<TB>("Tb"), \
                          SparseMatMulOp<TA, TB, SparseMatMul>);
#ifdef TENSORFLOW_USE_LIBXSMM
#define REGISTER_SPARSE_MATMUL_LIBXSMM(TA, TB)           \
  REGISTER_KERNEL_BUILDER(Name("SparseMatMul")           \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<TA>("Ta")  \
                              .TypeConstraint<TB>("Tb"), \
                          SparseMatMulOp<TA, TB, LibxsmmSparseMatMul>);
#endif
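
// For reference, an instantiation such as REGISTER_SPARSE_MATMUL(float,
// bfloat16) expands (roughly) to:
//
//   REGISTER_KERNEL_BUILDER(Name("SparseMatMul")
//                               .Device(DEVICE_CPU)
//                               .TypeConstraint<float>("Ta")
//                               .TypeConstraint<bfloat16>("Tb"),
//                           SparseMatMulOp<float, bfloat16, SparseMatMul>);
//
// i.e. it registers a CPU kernel for the "SparseMatMul" op keyed on the
// (Ta, Tb) attribute pair.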

REGISTER_SPARSE_MATMUL(float, bfloat16);
REGISTER_SPARSE_MATMUL(bfloat16, float);

#ifdef TENSORFLOW_USE_LIBXSMM
REGISTER_SPARSE_MATMUL_LIBXSMM(bfloat16, bfloat16);
REGISTER_SPARSE_MATMUL_LIBXSMM(float, float);
#else
REGISTER_SPARSE_MATMUL(bfloat16, bfloat16);
REGISTER_SPARSE_MATMUL(float, float);
#endif

#undef REGISTER_SPARSE_MATMUL

}  // end namespace tensorflow