#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAUtils.h>
#include <ATen/Dispatch.h>

#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
#include <cuda_runtime.h>
#include <cutlass/cutlass.h>
#include <cutlass/layout/layout.h>
#include <cutlass/tensor_ref.h>
#include <cutlass/gemm/device/gemm_sparse_with_visitor.h>
#include <cutlass/epilogue/threadblock/fusion/visitors.hpp>
#endif

#include <type_traits>
#include <tuple>

#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
#define CUTLASS_STATUS_CHECK(status)                                    \
  {                                                                     \
    TORCH_CHECK(status == cutlass::Status::kSuccess,                    \
                "Got CUTLASS error: ", cutlassGetStatusString(status)); \
  }

namespace {
enum class Activation { NONE, RELU, SILU };
}
#endif

namespace at::native {

#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
// Wrapper function for CUTLASS sparse GEMM implementation, used
// solely to simplify dispatching from
// _sparse_semi_structured_linear() function below.
template <
    typename ElementInputA,
    typename ElementInputB,
    typename ElementOutput,
    typename ElementAccumulator,
    typename ThreadblockShape,
    typename WarpShape,
    typename InstructionShape,
    typename LayoutInputA,
    typename LayoutInputB,
    bool use_bias,
    Activation activation>
Tensor two_four_sgemm(
    const Tensor& tensor_a,
    const at::IntArrayRef::value_type& tensor_a_stride,
    const Tensor& tensor_b,
    const at::IntArrayRef::value_type& tensor_b_stride,
    const Tensor& tensor_c, const Tensor& meta) {
    // Fix the CUTLASS sparse GEMM template arguments that are not
    // provided as template arguments of this function, and create an
    // alias for the particular instantiation of this template.
    using LayoutOutput = cutlass::layout::RowMajor; // Result of the operation will be provided in row-major format.
    using MMAOp = cutlass::arch::OpClassTensorOp; // Tensor cores are to be used for maximum performance.
    using SmArch = cutlass::arch::Sm80; // Only CC 8.x devices are supported at the moment.
    using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // This choice provides good performance across a wide range of operand sizes.
    constexpr int NumStages = 3; // This choice provides good performance across a wide range of operand sizes.
    using Operator = cutlass::arch::OpMultiplyAdd;
    constexpr int NumEVTEpilogueStages = 1;

    constexpr int AlignmentInputA = 128 / cutlass::sizeof_bits<ElementInputA>::value;
    constexpr int AlignmentInputB = 128 / cutlass::sizeof_bits<ElementInputB>::value;
    constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;
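    // Note: the alignments above correspond to the widest (128-bit)
    // vectorized accesses; for example, a 16-bit element type yields
    // 128 / 16 = 8 elements per access.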

    using ElementComputeEpilogue = ElementAccumulator;
    constexpr int AlignmentComputeEpilogue = 128 / cutlass::sizeof_bits<ElementComputeEpilogue>::value;
    using ElementC = ElementOutput;
    using LayoutC = LayoutOutput;
    constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;

    using BiasTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
        ThreadblockShape,
        WarpShape,
        ElementC,
        AlignmentC,
        NumEVTEpilogueStages>;
    using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
        ThreadblockShape,
        WarpShape,
        ElementOutput,
        AlignmentOutput,
        NumEVTEpilogueStages>;

    using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

    using BiasScalar =
        cutlass::epilogue::threadblock::VisitorScalarBroadcast<ElementC>;
    using BiasTensor =
        cutlass::epilogue::threadblock::VisitorColBroadcast<
            BiasTileThreadMap,
            ElementC,
            cute::Stride<cute::_1, cute::_0, int64_t>>;
    using Bias = std::conditional_t<use_bias, BiasTensor, BiasScalar>;
    using BiasArguments = typename Bias::Arguments;

    using ApplyBias = cutlass::epilogue::threadblock::VisitorCompute<
        cutlass::plus, ElementComputeEpilogue, ElementComputeEpilogue,
        cutlass::FloatRoundStyle::round_to_nearest>;
    using EVTApplyBias = cutlass::epilogue::threadblock::Sm80EVT<
        ApplyBias,
        Accum,
        Bias>;

    using ApplyActivationNone = cutlass::epilogue::threadblock::VisitorCompute<
        cutlass::epilogue::thread::Identity,
        ElementComputeEpilogue,
        ElementComputeEpilogue,
        cutlass::FloatRoundStyle::round_to_nearest>;
    using ApplyActivationReLu = cutlass::epilogue::threadblock::VisitorCompute<
        cutlass::epilogue::thread::ReLu,
        ElementComputeEpilogue,
        ElementComputeEpilogue,
        cutlass::FloatRoundStyle::round_to_nearest>;
    using ApplyActivationSiLu = cutlass::epilogue::threadblock::VisitorCompute<
        cutlass::epilogue::thread::SiLu,
        ElementComputeEpilogue,
        ElementComputeEpilogue,
        cutlass::FloatRoundStyle::round_to_nearest>;
    using ApplyActivation =
        std::conditional_t<
            activation == Activation::NONE,
            ApplyActivationNone,
            std::conditional_t<
                activation == Activation::RELU,
                ApplyActivationReLu,
                ApplyActivationSiLu>>;
    using EVTApplyActivation = cutlass::epilogue::threadblock::Sm80EVT<
        ApplyActivation,
        EVTApplyBias>;

    using Output = cutlass::epilogue::threadblock::VisitorAuxStore<
        OutputTileThreadMap, ElementOutput, cutlass::FloatRoundStyle::round_to_nearest,
        cute::Stride<int64_t, cute::_1, int64_t>>;

    using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT<
        Output,
        EVTApplyActivation>;
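    // Taken together, the visitor tree above computes
    // D = activation(accum + bias) and stores the result; when use_bias
    // is false, the bias term is a broadcast scalar zero (see the
    // argument construction below).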

    using Gemm = cutlass::gemm::device::SparseGemmWithVisitor<
        ElementInputA,
        LayoutInputA,
        ElementInputB,
        LayoutInputB,
        ElementC,
        LayoutC,
        ElementAccumulator,
        MMAOp,
        SmArch,
        ThreadblockShape,
        WarpShape,
        InstructionShape,
        EVTOutput,
        SwizzleThreadBlock,
        NumStages,
        AlignmentInputA,
        AlignmentInputB,
        Operator,
        NumEVTEpilogueStages>;

    // Datatype and layout of metadata matrix are inferred from sparse
    // GEMM template.
    using ElementInputE = typename Gemm::ElementE;
    using LayoutInputE = cutlass::layout::RowMajor;
    using ReorderedLayoutInputE = typename Gemm::LayoutE;
    static_assert(
        std::is_same<ReorderedLayoutInputE,
                     cutlass::layout::ColumnMajorInterleaved<2>>::value,
        "Matrix layout used by CUTLASS for reordered metadata for sparse GEMM "
        "has changed; code doing conversions from/to the dense matrix format "
        "has to be updated.");

    constexpr auto kSparse = Gemm::kSparse;
    constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;

    // Operand sizes.
    const int length_m = tensor_a.size(0);
    const int length_k = tensor_b.size(0);
    const int length_n = tensor_b.size(1);
    const auto meta_ncols = length_k / kSparse / kElementsPerElementE;
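    // Each metadata element packs the indices of kElementsPerElementE kept
    // values of the compressed operand. For example, with 16-bit inputs the
    // metadata elements are 16-bit and kElementsPerElementE is typically 8,
    // so length_k == 128 yields 128 / 2 / 8 = 8 metadata columns.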

    // Determine PyTorch datatype for the metadata matrix.
    auto meta_dtype = at::kChar;
    switch (sizeof(ElementInputE)) {
    case 2:
        meta_dtype = at::kShort;
        break;
    case 4:
        meta_dtype = at::kInt;
        break;
    default:
        AT_ERROR("two_four_sgemm: invalid size of meta tensor datatype "
                 "encountered");
    }
    TORCH_CHECK(meta.dtype() == meta_dtype,
                "two_four_sgemm: Expected meta datatype ", meta_dtype,
                ", but got ", meta.dtype());

    // Determine PyTorch datatype for the output matrix.
    auto tensor_d_dtype = at::kChar;
    if constexpr (std::is_same_v<ElementOutput, int8_t>) {
        tensor_d_dtype = at::kChar;
    } else if constexpr (std::is_same_v<ElementOutput, int32_t>) {
        tensor_d_dtype = at::kInt;
    } else if constexpr (std::is_same_v<ElementOutput, cutlass::half_t>) {
        tensor_d_dtype = at::kHalf;
    } else if constexpr (std::is_same_v<ElementOutput, cutlass::bfloat16_t>) {
        tensor_d_dtype = at::kBFloat16;
    } else if constexpr (std::is_same_v<ElementOutput, float>) {
        tensor_d_dtype = at::kFloat;
    } else {
        AT_ERROR("two_four_sgemm: invalid datatype for sparse GEMM output ",
                 "encountered");
    }
    if constexpr (use_bias) {
        TORCH_CHECK(tensor_c.dtype() == tensor_d_dtype,
                    "two_four_sgemm: Expected sparse GEMM bias datatype ",
                    tensor_d_dtype, ", but got ", tensor_c.dtype());
    }

    // Create output matrix.
    Tensor tensor_d =
        tensor_a.new_empty({length_m, length_n},
                           at::TensorOptions().dtype(tensor_d_dtype));

    // Prepare arguments for CUTLASS sparse GEMM kernel.
    cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);
    LayoutInputA layout_a(tensor_a_stride);
    LayoutInputB layout_b(tensor_b_stride);
    auto tensor_a_device_ref =
        cutlass::TensorRef<ElementInputA, LayoutInputA>(
            (ElementInputA*)tensor_a.data_ptr(), layout_a);
    auto tensor_b_device_ref =
        cutlass::TensorRef<ElementInputB, LayoutInputB>(
            (ElementInputB*)tensor_b.data_ptr(), layout_b);
    auto tensor_e_reordered_device_ref =
        cutlass::TensorRef<ElementInputE, ReorderedLayoutInputE>(
            (ElementInputE*)meta.data_ptr(),
            ReorderedLayoutInputE::packed({length_m, meta_ncols}));

    BiasArguments bias_arguments{
        [&]() -> BiasArguments {
            if constexpr (use_bias) {
                return {(ElementC*)tensor_c.data_ptr(),
                        ElementC(0),
                        {cute::_1{}, cute::_0{}, problem_size.m()}};
            } else {
                return {ElementC(0)};
            }
        }()
    };
    typename Output::Arguments output_arguments{
        (ElementOutput*)tensor_d.data_ptr(),
        {problem_size.n(), cute::_1{}, problem_size.mn().product()}
    };
    typename EVTOutput::Arguments callback_arguments{
        {
            {
                {},              // Accum
                bias_arguments,  // Bias
                {}               // ApplyBias
            },                   // EVTApplyBias
            {}                   // ApplyActivation
        },                       // EVTApplyActivation
        output_arguments,        // Output
    };                           // EVTOutput

    // Create a tuple of CUTLASS sparse GEMM kernel arguments.
    typename Gemm::Arguments arguments{
        problem_size,
        tensor_a_device_ref,
        tensor_b_device_ref,
        tensor_e_reordered_device_ref,
        callback_arguments};

    cutlass::Status status;

    // Create CUTLASS sparse GEMM kernel object.
    Gemm gemm_op;

    // Verify that sparse GEMM operation with given arguments can be
    // performed by CUTLASS.
    status = gemm_op.can_implement(arguments);
    CUTLASS_STATUS_CHECK(status);

    // Allocate workspace for CUTLASS sparse GEMM kernel.
    const auto workspace_size = Gemm::get_workspace_size(arguments);
    auto workspace = tensor_a.new_empty({(int64_t)workspace_size},
                                        at::TensorOptions().dtype(at::kByte));

    // Initialize CUTLASS sparse GEMM object.
    status = gemm_op.initialize(arguments, workspace.data_ptr(),
                                at::cuda::getCurrentCUDAStream());
    CUTLASS_STATUS_CHECK(status);

    // Perform sparse GEMM operation.
    status = gemm_op.run(at::cuda::getCurrentCUDAStream());
    CUTLASS_STATUS_CHECK(status);

    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return tensor_d;
}

// Dispatch according to the combination of input tensor layouts.
template <
    typename ElementInputA,
    typename ElementInputB,
    typename ElementOutput,
    typename ElementAccumulator,
    typename ThreadblockShape,
    typename WarpShape,
    typename InstructionShape,
    bool EnableRowMajorRowMajorLayouts,
    bool EnableRowMajorColumnMajorLayouts,
    bool EnableColumnMajorRowMajorLayouts,
    bool EnableColumnMajorColumnMajorLayouts,
    bool use_bias,
    Activation activation>
Tensor two_four_sgemm_dispatch_layouts(
    const Tensor& tensor_a, const Tensor& tensor_b, const Tensor& tensor_c,
    const Tensor& meta) {
    // Determine layouts (row-major or column-major) of input tensors.
    const auto strides_a = tensor_a.strides();
    auto tensor_a_row_major = strides_a[1] == 1;
    auto tensor_a_stride = tensor_a_row_major ? strides_a[0] : strides_a[1];
    const auto strides_b = tensor_b.strides();
    auto tensor_b_row_major = strides_b[1] == 1;
    auto tensor_b_stride = tensor_b_row_major ? strides_b[0] : strides_b[1];
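    // For example, a contiguous (row-major) m x k tensor has strides (k, 1),
    // so a unit stride in the last dimension marks it as row-major and the
    // other stride becomes the leading dimension passed to CUTLASS.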

    // Perform dispatching.
    if constexpr (EnableRowMajorRowMajorLayouts) {
        if (tensor_a_row_major && tensor_b_row_major) {
            return two_four_sgemm<
                ElementInputA,
                ElementInputB,
                ElementOutput,
                ElementAccumulator,
                ThreadblockShape,
                WarpShape,
                InstructionShape,
                cutlass::layout::RowMajor,
                cutlass::layout::RowMajor,
                use_bias,
                activation>(
                tensor_a,
                tensor_a_stride,
                tensor_b,
                tensor_b_stride,
                tensor_c,
                meta);
        }
    }
    if constexpr (EnableRowMajorColumnMajorLayouts) {
        if (tensor_a_row_major && !tensor_b_row_major) {
            return two_four_sgemm<
                ElementInputA,
                ElementInputB,
                ElementOutput,
                ElementAccumulator,
                ThreadblockShape,
                WarpShape,
                InstructionShape,
                cutlass::layout::RowMajor,
                cutlass::layout::ColumnMajor,
                use_bias,
                activation>(
                tensor_a,
                tensor_a_stride,
                tensor_b,
                tensor_b_stride,
                tensor_c,
                meta);
        }
    }
    if constexpr (EnableColumnMajorRowMajorLayouts) {
        if (!tensor_a_row_major && tensor_b_row_major) {
            return two_four_sgemm<
                ElementInputA,
                ElementInputB,
                ElementOutput,
                ElementAccumulator,
                ThreadblockShape,
                WarpShape,
                InstructionShape,
                cutlass::layout::ColumnMajor,
                cutlass::layout::RowMajor,
                use_bias,
                activation>(
                tensor_a,
                tensor_a_stride,
                tensor_b,
                tensor_b_stride,
                tensor_c,
                meta);
        }
    }
    if constexpr (EnableColumnMajorColumnMajorLayouts) {
        if (!tensor_a_row_major && !tensor_b_row_major) {
            return two_four_sgemm<
                ElementInputA,
                ElementInputB,
                ElementOutput,
                ElementAccumulator,
                ThreadblockShape,
                WarpShape,
                InstructionShape,
                cutlass::layout::ColumnMajor,
                cutlass::layout::ColumnMajor,
                use_bias,
                activation>(
                tensor_a,
                tensor_a_stride,
                tensor_b,
                tensor_b_stride,
                tensor_c,
                meta);
        }
    }

    AT_ERROR("two_four_sgemm_dispatch_layouts: Combination of ",
             tensor_a_row_major ? "row-major" : "column-major", " and ",
             tensor_b_row_major ? "row-major" : "column-major",
             " layouts for input tensors is not supported");
    return Tensor{};
}

// Dispatch according to whether a bias tensor is provided.
template <
    typename ElementInputA,
    typename ElementInputB,
    typename ElementOutput,
    typename ElementAccumulator,
    typename ThreadblockShape,
    typename WarpShape,
    typename InstructionShape,
    bool EnableRowMajorRowMajorLayouts,
    bool EnableRowMajorColumnMajorLayouts,
    bool EnableColumnMajorRowMajorLayouts,
    bool EnableColumnMajorColumnMajorLayouts,
    Activation activation>
Tensor two_four_sgemm_dispatch_layouts_bias(
    const Tensor& tensor_a, const Tensor& tensor_b, const Tensor& tensor_c,
    const Tensor& meta) {
    if (tensor_c.numel() > 0) {
        return two_four_sgemm_dispatch_layouts<
            ElementInputA,
            ElementInputB,
            ElementOutput,
            ElementAccumulator,
            ThreadblockShape,
            WarpShape,
            InstructionShape,
            EnableRowMajorRowMajorLayouts,
            EnableRowMajorColumnMajorLayouts,
            EnableColumnMajorRowMajorLayouts,
            EnableColumnMajorColumnMajorLayouts,
            true,
            activation>(
            tensor_a,
            tensor_b,
            tensor_c,
            meta);
    } else {
        return two_four_sgemm_dispatch_layouts<
            ElementInputA,
            ElementInputB,
            ElementOutput,
            ElementAccumulator,
            ThreadblockShape,
            WarpShape,
            InstructionShape,
            EnableRowMajorRowMajorLayouts,
            EnableRowMajorColumnMajorLayouts,
            EnableColumnMajorRowMajorLayouts,
            EnableColumnMajorColumnMajorLayouts,
            false,
            activation>(
            tensor_a,
            tensor_b,
            tensor_c,
            meta);
    }
}

// Dispatch according to the activation functions enabled.
template <
    typename ElementInputA,
    typename ElementInputB,
    typename ElementOutput,
    typename ElementAccumulator,
    typename ThreadblockShape,
    typename WarpShape,
    typename InstructionShape,
    bool EnableRowMajorRowMajorLayouts,
    bool EnableRowMajorColumnMajorLayouts,
    bool EnableColumnMajorRowMajorLayouts,
    bool EnableColumnMajorColumnMajorLayouts,
    bool EnableActivationNone,
    bool EnableActivationReLU,
    bool EnableActivationSiLU>
Tensor two_four_sgemm_dispatch_layouts_bias_activation(
    const Tensor& tensor_a, const Tensor& tensor_b, const Tensor& tensor_c,
    const Tensor& meta, const c10::string_view& activation) {
    // Perform dispatching.
    if constexpr (EnableActivationNone) {
        if (activation == "none") {
            return two_four_sgemm_dispatch_layouts_bias<
                ElementInputA,
                ElementInputB,
                ElementOutput,
                ElementAccumulator,
                ThreadblockShape,
                WarpShape,
                InstructionShape,
                EnableRowMajorRowMajorLayouts,
                EnableRowMajorColumnMajorLayouts,
                EnableColumnMajorRowMajorLayouts,
                EnableColumnMajorColumnMajorLayouts,
                Activation::NONE>(
                tensor_a,
                tensor_b,
                tensor_c,
                meta);
        }
    }
    if constexpr (EnableActivationReLU) {
        if (activation == "relu") {
            return two_four_sgemm_dispatch_layouts_bias<
                ElementInputA,
                ElementInputB,
                ElementOutput,
                ElementAccumulator,
                ThreadblockShape,
                WarpShape,
                InstructionShape,
                EnableRowMajorRowMajorLayouts,
                EnableRowMajorColumnMajorLayouts,
                EnableColumnMajorRowMajorLayouts,
                EnableColumnMajorColumnMajorLayouts,
                Activation::RELU>(
                tensor_a,
                tensor_b,
                tensor_c,
                meta);
        }
    }
    if constexpr (EnableActivationSiLU) {
        if (activation == "silu") {
            return two_four_sgemm_dispatch_layouts_bias<
                ElementInputA,
                ElementInputB,
                ElementOutput,
                ElementAccumulator,
                ThreadblockShape,
                WarpShape,
                InstructionShape,
                EnableRowMajorRowMajorLayouts,
                EnableRowMajorColumnMajorLayouts,
                EnableColumnMajorRowMajorLayouts,
                EnableColumnMajorColumnMajorLayouts,
                Activation::SILU>(
                tensor_a,
                tensor_b,
                tensor_c,
                meta);
        }
    }

    AT_ERROR("two_four_sgemm_dispatch_layouts_bias_activation: Activation \"",
             activation, "\" is not supported for given input tensors");
    return Tensor{};
}
#endif

// Perform a linear transformation on the given arguments, using the
// corresponding CUTLASS sparse GEMM kernel:
//     output = input * weight.T + bias
// The "input" tensor is a dense tensor, while the "weight" tensor is
// a matrix with a 2:4 sparsity pattern. The "bias" tensor is optional;
// if provided, it should be a vector with the number of elements
// equal to the number of rows of the "weight" matrix. It is assumed
// that "input" (after squashing eventual batch dimensions with the
// next-to-last dimension of this tensor) and "weight" are supplied
// either in row-major or column-major layout (the two tensors may use
// different layouts, but not all combinations of layouts are
// supported for all datatypes of these matrices). The "meta" argument
// contains the metadata matrix. The function returns the output
// tensor.
//
// There are numerous limitations of the CUTLASS sparse GEMM kernel
// with regard to the sizes, alignments, layouts and datatypes of the
// input tensors, and so on; this is the reason for the large number
// of checks throughout the code.
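//
// Shape sketch of the operands, as implied by the checks below and by
// two_four_sgemm() above (illustrative only; "*" stands for optional
// batch dimensions and k for the dense reduction dimension):
//     input  : [*, k]                              dense
//     weight : [m, k / 2]                          compressed 2:4-sparse values
//     meta   : [m, k / 2 / kElementsPerElementE]   reordered metadata
//     bias   : [m]                                 optional
//     output : [*, m]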
Tensor _sparse_semi_structured_linear(
    const Tensor& input, const Tensor& weight,
    const Tensor& meta, const std::optional<Tensor>& bias_opt,
    const std::optional<c10::string_view> activation_opt,
    const std::optional<c10::ScalarType> out_dtype_opt) {
    TORCH_WARN_ONCE("_sparse_semi_structured_linear is deprecated and will be "
                    "removed in a future PyTorch release. Please use "
                    "_sparse_semi_structured_mm/_sparse_semi_structured_addmm "
                    "instead.");
#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
    AT_ERROR("_sparse_semi_structured_linear: CUTLASS not supported");
    return Tensor{};
#else
    // No need to check that all tensors are on CUDA device, as this
    // is provided by dispatch.

    // Introduce alias names for arguments, according to the CUTLASS
    // naming conventions. Also, squash the batch dimensions of the
    // input tensor with its next-to-last dimension.
    const auto input_sizes = input.sizes().vec();
    const auto tensor_a = weight;
    const auto tensor_b =
        input.reshape({-1, input_sizes.back()}).transpose(-1, -2);
    const auto tensor_c = bias_opt.has_value() ? *bias_opt : Tensor{};

    const auto activation =
        activation_opt.has_value() ? *activation_opt : "none";

    TORCH_CHECK(!out_dtype_opt.has_value() ||
                (tensor_a.dtype() == at::ScalarType::Char &&
                 out_dtype_opt.value() == at::ScalarType::Int),
                "_sparse_semi_structured_linear: Setting out_dtype is only "
                "supported for int8 input and int32 output");

    // For now, only CC 8.x devices are supported.
    const auto dprops = at::cuda::getCurrentDeviceProperties();
    const auto is_sm8x = dprops->major == 8;
    TORCH_CHECK(is_sm8x,
                "_sparse_semi_structured_linear: Supported only on GPUs with "
                "compute capability 8.x");

    // Validate datatypes of input tensors.
    TORCH_CHECK(tensor_a.dtype() == at::kChar ||
                tensor_a.dtype() == at::kHalf ||
                tensor_a.dtype() == at::kBFloat16 ||
                tensor_a.dtype() == at::kFloat,
                "_sparse_semi_structured_linear: The weight datatype ",
                tensor_a.dtype(), " is not supported");
    TORCH_CHECK(tensor_b.dtype() == tensor_a.dtype(),
                "_sparse_semi_structured_linear: Expected input datatype ",
                tensor_a.dtype(), ", but got ", tensor_b.dtype());

    // Validate layouts of input tensors.
    TORCH_CHECK(tensor_a.layout() == Layout::Strided,
                "_sparse_semi_structured_linear: Expected weight argument "
                "to be strided, but got layout ", tensor_a.layout());
    TORCH_CHECK(tensor_a.dim() == 2,
                "_sparse_semi_structured_linear: Expected weight argument "
                "to be 2D tensor, got ", tensor_a.dim(), " dims");
    const auto strides_a = tensor_a.strides();
    TORCH_CHECK((strides_a[0] == 1 || strides_a[1] == 1) &&
                strides_a[0] != strides_a[1],
                "_sparse_semi_structured_linear: Invalid strides for weight "
                "argument: row stride = ", strides_a[0], ", column stride = ",
                strides_a[1]);
    TORCH_CHECK(tensor_b.layout() == Layout::Strided,
                "_sparse_semi_structured_linear: Expected input argument "
                "to be strided, but got layout ", tensor_b.layout());
    TORCH_CHECK(tensor_b.dim() == 2,
                "_sparse_semi_structured_linear: Expected input argument "
                "to be 2D tensor, got ", tensor_b.dim(), " dims");
    const auto strides_b = tensor_b.strides();
    TORCH_CHECK((strides_b[0] == 1 || strides_b[1] == 1) &&
                strides_b[0] != strides_b[1],
                "_sparse_semi_structured_linear: Invalid strides for input "
                "argument: row stride = ", strides_b[0], ", column stride = ",
                strides_b[1]);
    if (tensor_c.numel() != 0) {
        TORCH_CHECK(tensor_c.layout() == Layout::Strided,
                    "_sparse_semi_structured_linear: Expected bias argument "
                    "to be strided, but got layout ", tensor_c.layout());
        TORCH_CHECK(tensor_c.dim() == 1,
                    "_sparse_semi_structured_linear: Expected bias argument "
                    "to be 1D tensor, got ", tensor_c.dim(), " dims");
    }

    // Validate sizes of input tensors.
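    // The weight matrix is stored in compressed form: an m x k matrix with
    // 2:4 sparsity keeps only k / 2 values per row, so its second dimension
    // must be half of the dense reduction dimension of the input.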
    TORCH_CHECK(tensor_a.size(1) == tensor_b.size(0) / 2,
                "_sparse_semi_structured_linear: Expected weight argument "
                "to have ", tensor_b.size(0) / 2, " columns, but got ",
                tensor_a.size(1));
    if (tensor_c.numel() != 0) {
        TORCH_CHECK(tensor_c.size(0) == tensor_a.size(0),
                    "_sparse_semi_structured_linear: Expected bias argument "
                    "to have ", tensor_a.size(0), " elements, but got ",
                    tensor_c.size(0));
    }

    // Call the wrapper function for CUTLASS sparse GEMM, dispatching on
    // the input datatype and then on the input tensor layouts.
    // According to the input tensors' datatypes and layouts, the
    // corresponding template arguments are supplied for instantiating
    // the wrapper function. The tile size template arguments were
    // selected according to CUTLASS profiler results over a number
    // of runs.
    Tensor output;
    AT_DISPATCH_SWITCH(
        tensor_a.scalar_type(),
        "_sparse_semi_structured_linear",
        AT_DISPATCH_CASE(
            at::ScalarType::Char,
            [&]() {
                using ElementInputA = int8_t;
                using ElementInputB = int8_t;
                using ElementAccumulator = int32_t;
                using ThreadblockShape =
                    cutlass::gemm::GemmShape<128, 128, 128>;
                using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>;
                using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
                const auto EnableRowMajorRowMajorLayouts = false;
                const auto EnableRowMajorColumnMajorLayouts = true;
                const auto EnableColumnMajorRowMajorLayouts = false;
                const auto EnableColumnMajorColumnMajorLayouts = false;
                const auto EnableActivationNone = true;
                const auto EnableActivationReLU = true;
                const auto EnableActivationSiLU = false;
                if (out_dtype_opt.has_value()) {
                    using ElementOutput = int32_t;
                    output = two_four_sgemm_dispatch_layouts_bias_activation<
                        ElementInputA,
                        ElementInputB,
                        ElementOutput,
                        ElementAccumulator,
                        ThreadblockShape,
                        WarpShape,
                        InstructionShape,
                        EnableRowMajorRowMajorLayouts,
                        EnableRowMajorColumnMajorLayouts,
                        EnableColumnMajorRowMajorLayouts,
                        EnableColumnMajorColumnMajorLayouts,
                        EnableActivationNone,
                        EnableActivationReLU,
                        EnableActivationSiLU>(
                        tensor_a,
                        tensor_b,
                        tensor_c,
                        meta,
                        activation);
                } else {
                    using ElementOutput = int8_t;
                    output = two_four_sgemm_dispatch_layouts_bias_activation<
                        ElementInputA,
                        ElementInputB,
                        ElementOutput,
                        ElementAccumulator,
                        ThreadblockShape,
                        WarpShape,
                        InstructionShape,
                        EnableRowMajorRowMajorLayouts,
                        EnableRowMajorColumnMajorLayouts,
                        EnableColumnMajorRowMajorLayouts,
                        EnableColumnMajorColumnMajorLayouts,
                        EnableActivationNone,
                        EnableActivationReLU,
                        EnableActivationSiLU>(
                        tensor_a,
                        tensor_b,
                        tensor_c,
                        meta,
                        activation);
                }
                return;
            })
        AT_DISPATCH_CASE(
            at::ScalarType::Half,
            [&]() {
                using ElementInputA = cutlass::half_t;
                using ElementInputB = cutlass::half_t;
                using ElementOutput = cutlass::half_t;
                using ElementAccumulator = float;
                using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
                using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
                using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
                const auto EnableRowMajorRowMajorLayouts = true;
                const auto EnableRowMajorColumnMajorLayouts = true;
                const auto EnableColumnMajorRowMajorLayouts = true;
                const auto EnableColumnMajorColumnMajorLayouts = true;
                const auto EnableActivationNone = true;
                const auto EnableActivationReLU = true;
                const auto EnableActivationSiLU = true;
                output = two_four_sgemm_dispatch_layouts_bias_activation<
                    ElementInputA,
                    ElementInputB,
                    ElementOutput,
                    ElementAccumulator,
                    ThreadblockShape,
                    WarpShape,
                    InstructionShape,
                    EnableRowMajorRowMajorLayouts,
                    EnableRowMajorColumnMajorLayouts,
                    EnableColumnMajorRowMajorLayouts,
                    EnableColumnMajorColumnMajorLayouts,
                    EnableActivationNone,
                    EnableActivationReLU,
                    EnableActivationSiLU>(
                    tensor_a,
                    tensor_b,
                    tensor_c,
                    meta,
                    activation);
                return;
            })
        AT_DISPATCH_CASE(
            at::ScalarType::BFloat16,
            [&]() {
                using ElementInputA = cutlass::bfloat16_t;
                using ElementInputB = cutlass::bfloat16_t;
                using ElementOutput = cutlass::bfloat16_t;
                using ElementAccumulator = float;
                using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
                using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
                using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
                const auto EnableRowMajorRowMajorLayouts = true;
                const auto EnableRowMajorColumnMajorLayouts = true;
                const auto EnableColumnMajorRowMajorLayouts = true;
                const auto EnableColumnMajorColumnMajorLayouts = true;
                const auto EnableActivationNone = true;
                const auto EnableActivationReLU = true;
                const auto EnableActivationSiLU = true;
                output = two_four_sgemm_dispatch_layouts_bias_activation<
                    ElementInputA,
                    ElementInputB,
                    ElementOutput,
                    ElementAccumulator,
                    ThreadblockShape,
                    WarpShape,
                    InstructionShape,
                    EnableRowMajorRowMajorLayouts,
                    EnableRowMajorColumnMajorLayouts,
                    EnableColumnMajorRowMajorLayouts,
                    EnableColumnMajorColumnMajorLayouts,
                    EnableActivationNone,
                    EnableActivationReLU,
                    EnableActivationSiLU>(
                    tensor_a,
                    tensor_b,
                    tensor_c,
                    meta,
                    activation);
                return;
            })
        AT_DISPATCH_CASE(
            at::ScalarType::Float,
            [&]() {
                using ElementInputA = float;
                using ElementInputB = float;
                using ElementOutput = float;
                using ElementAccumulator = float;
                using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>;
                using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>;
                using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
                const auto EnableRowMajorRowMajorLayouts = true;
                const auto EnableRowMajorColumnMajorLayouts = true;
                const auto EnableColumnMajorRowMajorLayouts = true;
                const auto EnableColumnMajorColumnMajorLayouts = true;
                const auto EnableActivationNone = true;
                const auto EnableActivationReLU = true;
                const auto EnableActivationSiLU = true;
                output = two_four_sgemm_dispatch_layouts_bias_activation<
                    ElementInputA,
                    ElementInputB,
                    ElementOutput,
                    ElementAccumulator,
                    ThreadblockShape,
                    WarpShape,
                    InstructionShape,
                    EnableRowMajorRowMajorLayouts,
                    EnableRowMajorColumnMajorLayouts,
                    EnableColumnMajorRowMajorLayouts,
                    EnableColumnMajorColumnMajorLayouts,
                    EnableActivationNone,
                    EnableActivationReLU,
                    EnableActivationSiLU>(
                    tensor_a,
                    tensor_b,
                    tensor_c,
                    meta,
                    activation);
                return;
            }));

    // Re-introduce batch dimensions into the output, and return.
    auto output_sizes = input_sizes;
    output_sizes.back() = weight.size(0);
    return output.transpose(-1, -2).reshape(output_sizes);
#endif
}

} // namespace at::native