#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAUtils.h>
#include <ATen/Dispatch.h>

#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
#include <cuda_runtime.h>
#include <cutlass/cutlass.h>
#include <cutlass/layout/layout.h>
#include <cutlass/tensor_ref.h>
#include <cutlass/gemm/device/gemm_sparse_with_visitor.h>
#include <cutlass/epilogue/threadblock/fusion/visitors.hpp>
#endif

#include <type_traits>
#include <tuple>

#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
#define CUTLASS_STATUS_CHECK(status)                                      \
  {                                                                       \
    TORCH_CHECK(status == cutlass::Status::kSuccess,                      \
                "Got CUTLASS error: ", cutlassGetStatusString(status));   \
  }

namespace {
    enum class Activation{NONE, RELU, SILU};
}
#endif

namespace at::native {

#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
// Wrapper function for the CUTLASS sparse GEMM implementation, used
// solely to simplify dispatching from the
// _sparse_semi_structured_linear() function below.
template <
    typename ElementInputA,
    typename ElementInputB,
    typename ElementOutput,
    typename ElementAccumulator,
    typename ThreadblockShape,
    typename WarpShape,
    typename InstructionShape,
    typename LayoutInputA,
    typename LayoutInputB,
    bool use_bias,
    Activation activation>
Tensor two_four_sgemm(
    const Tensor& tensor_a,
    const at::IntArrayRef::value_type& tensor_a_stride,
    const Tensor& tensor_b,
    const at::IntArrayRef::value_type& tensor_b_stride,
    const Tensor& tensor_c, const Tensor& meta) {
    // Fix the CUTLASS sparse GEMM template arguments that are not
    // provided as template arguments of this function, and create an
    // alias for the particular instantiation of this template.
    using LayoutOutput = cutlass::layout::RowMajor; // Result of the operation will be provided in row-major format.
    using MMAOp = cutlass::arch::OpClassTensorOp; // Tensor cores are to be used for maximum performance.
    using SmArch = cutlass::arch::Sm80; // Only CC 8.x devices are supported at the moment.
    using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // This choice provides good performance across a wide range of operand sizes.
    constexpr int NumStages = 3; // This choice provides good performance across a wide range of operand sizes.
    using Operator = cutlass::arch::OpMultiplyAdd;
    constexpr int NumEVTEpilogueStages = 1;

    constexpr int AlignmentInputA = 128 / cutlass::sizeof_bits<ElementInputA>::value;
    constexpr int AlignmentInputB = 128 / cutlass::sizeof_bits<ElementInputB>::value;
    constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;
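    // These alignments express the widest 128-bit vector access per
    // operand: e.g. for cutlass::half_t (16 bits) this gives
    // 128 / 16 = 8 elements, and for int8_t, 128 / 8 = 16 elements.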

    using ElementComputeEpilogue = ElementAccumulator;
    constexpr int AlignmentComputeEpilogue = 128 / cutlass::sizeof_bits<ElementComputeEpilogue>::value;
    using ElementC = ElementOutput;
    using LayoutC = LayoutOutput;
    constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;

    using BiasTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
        ThreadblockShape,
        WarpShape,
        ElementC,
        AlignmentC,
        NumEVTEpilogueStages>;
    using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
        ThreadblockShape,
        WarpShape,
        ElementOutput,
        AlignmentOutput,
        NumEVTEpilogueStages>;

    using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

    using BiasScalar =
        cutlass::epilogue::threadblock::VisitorScalarBroadcast<ElementC>;
    using BiasTensor =
        cutlass::epilogue::threadblock::VisitorColBroadcast<
            BiasTileThreadMap,
            ElementC,
            cute::Stride<cute::_1, cute::_0, int64_t>>;
    using Bias = std::conditional_t<use_bias, BiasTensor, BiasScalar>;
    using BiasArguments = typename Bias::Arguments;

    using ApplyBias = cutlass::epilogue::threadblock::VisitorCompute<
        cutlass::plus, ElementComputeEpilogue, ElementComputeEpilogue,
        cutlass::FloatRoundStyle::round_to_nearest>;
    using EVTApplyBias = cutlass::epilogue::threadblock::Sm80EVT<
        ApplyBias,
        Accum,
        Bias>;

    using ApplyActivationNone = cutlass::epilogue::threadblock::VisitorCompute<
        cutlass::epilogue::thread::Identity,
        ElementComputeEpilogue,
        ElementComputeEpilogue,
        cutlass::FloatRoundStyle::round_to_nearest>;
    using ApplyActivationReLu = cutlass::epilogue::threadblock::VisitorCompute<
        cutlass::epilogue::thread::ReLu,
        ElementComputeEpilogue,
        ElementComputeEpilogue,
        cutlass::FloatRoundStyle::round_to_nearest>;
    using ApplyActivationSiLu = cutlass::epilogue::threadblock::VisitorCompute<
        cutlass::epilogue::thread::SiLu,
        ElementComputeEpilogue,
        ElementComputeEpilogue,
        cutlass::FloatRoundStyle::round_to_nearest>;
    using ApplyActivation =
        std::conditional_t<
            activation == Activation::NONE,
            ApplyActivationNone,
            std::conditional_t<
                activation == Activation::RELU,
                ApplyActivationReLu,
                ApplyActivationSiLu>>;
    using EVTApplyActivation = cutlass::epilogue::threadblock::Sm80EVT<
        ApplyActivation,
        EVTApplyBias>;

    using Output = cutlass::epilogue::threadblock::VisitorAuxStore<
        OutputTileThreadMap, ElementOutput, cutlass::FloatRoundStyle::round_to_nearest,
        cute::Stride<int64_t, cute::_1, int64_t>>;

    using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT<
        Output,
        EVTApplyActivation>;

    using Gemm = cutlass::gemm::device::SparseGemmWithVisitor<
        ElementInputA,
        LayoutInputA,
        ElementInputB,
        LayoutInputB,
        ElementC,
        LayoutC,
        ElementAccumulator,
        MMAOp,
        SmArch,
        ThreadblockShape,
        WarpShape,
        InstructionShape,
        EVTOutput,
        SwizzleThreadBlock,
        NumStages,
        AlignmentInputA,
        AlignmentInputB,
        Operator,
        NumEVTEpilogueStages>;

    // The datatype and layout of the metadata matrix are inferred from
    // the sparse GEMM template.
    using ElementInputE = typename Gemm::ElementE;
    using LayoutInputE = cutlass::layout::RowMajor;
    using ReorderedLayoutInputE = typename Gemm::LayoutE;
    static_assert(
        std::is_same<ReorderedLayoutInputE,
                     cutlass::layout::ColumnMajorInterleaved<2>>::value,
        "The matrix layout used by CUTLASS for reordered sparse GEMM "
        "metadata has changed; the code converting the metadata from/to "
        "the dense format has to be updated accordingly.");

    constexpr auto kSparse = Gemm::kSparse;
    constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;

    // Operand sizes.
    const int length_m = tensor_a.size(0);
    const int length_k = tensor_b.size(0);
    const int length_n = tensor_b.size(1);
    const auto meta_ncols = length_k / kSparse / kElementsPerElementE;
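    // Here kSparse == 2 for the 2:4 pattern (tensor_a holds the
    // compressed k / 2 columns), and kElementsPerElementE is the number
    // of compressed data elements described by one metadata element.
    // As an illustrative example, assuming the usual CUTLASS
    // configuration for 16-bit operands (ElementInputE = uint16_t,
    // kElementsPerElementE = 8): length_k = 256 gives
    // meta_ncols = 256 / 2 / 8 = 16.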

    // Determine PyTorch datatype for the metadata matrix.
    auto meta_dtype = at::kChar;
    switch (sizeof(ElementInputE)) {
    case 2:
        meta_dtype = at::kShort;
        break;
    case 4:
        meta_dtype = at::kInt;
        break;
    default:
        AT_ERROR("two_four_sgemm: invalid size of meta tensor datatype "
                 "encountered");
    }
    TORCH_CHECK(meta.dtype() == meta_dtype,
                "two_four_sgemm: Expected meta datatype ", meta_dtype,
                ", but got ", meta.dtype());

    // Determine PyTorch datatype for the output matrix.
    auto tensor_d_dtype = at::kChar;
    if constexpr (std::is_same_v<ElementOutput, int8_t>) {
        tensor_d_dtype = at::kChar;
    } else if constexpr (std::is_same_v<ElementOutput, int32_t>) {
        tensor_d_dtype = at::kInt;
    } else if constexpr (std::is_same_v<ElementOutput, cutlass::half_t>) {
        tensor_d_dtype = at::kHalf;
    } else if constexpr (std::is_same_v<ElementOutput, cutlass::bfloat16_t>) {
        tensor_d_dtype = at::kBFloat16;
    } else if constexpr (std::is_same_v<ElementOutput, float>) {
        tensor_d_dtype = at::kFloat;
    } else {
        AT_ERROR("two_four_sgemm: invalid datatype for sparse GEMM output "
                 "encountered");
    }
    if constexpr (use_bias) {
        TORCH_CHECK(tensor_c.dtype() == tensor_d_dtype,
                    "two_four_sgemm: Expected sparse GEMM bias datatype ",
                    tensor_d_dtype, ", but got ", tensor_c.dtype());
    }

    // Create output matrix.
    Tensor tensor_d =
        tensor_a.new_empty({length_m, length_n},
                           at::TensorOptions().dtype(tensor_d_dtype));

    // Prepare arguments for CUTLASS sparse GEMM kernel.
    cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);
    LayoutInputA layout_a(tensor_a_stride);
    LayoutInputB layout_b(tensor_b_stride);
    auto tensor_a_device_ref =
        cutlass::TensorRef<ElementInputA, LayoutInputA>(
            (ElementInputA*)tensor_a.data_ptr(), layout_a);
    auto tensor_b_device_ref =
        cutlass::TensorRef<ElementInputB, LayoutInputB>(
            (ElementInputB*)tensor_b.data_ptr(), layout_b);
    auto tensor_e_reordered_device_ref =
        cutlass::TensorRef<ElementInputE, ReorderedLayoutInputE>(
            (ElementInputE*)meta.data_ptr(),
            ReorderedLayoutInputE::packed({length_m, meta_ncols}));

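    // The immediately-invoked lambda below selects, at compile time,
    // between a pointer to the bias vector (one value per output row,
    // broadcast across columns) and a dummy zero scalar when no bias
    // is provided.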
    BiasArguments bias_arguments{
        [&]() -> BiasArguments {
            if constexpr (use_bias) {
                return {(ElementC*)tensor_c.data_ptr(),
                        ElementC(0),
                        {cute::_1{}, cute::_0{}, problem_size.m()}};
            } else {
                return {ElementC(0)};
            }
        }()
    };
    typename Output::Arguments output_arguments{
        (ElementOutput*)tensor_d.data_ptr(),
        {problem_size.n(), cute::_1{}, problem_size.mn().product()}
    };
    typename EVTOutput::Arguments callback_arguments{
        {
            {
                {},                 // Accum
                bias_arguments,     // Bias
                {}                  // ApplyBias
            },                      // EVTApplyBias
            {}                      // ApplyActivation
        },                          // EVTApplyActivation
        output_arguments,           // Output
    };                              // EVTOutput

    // Create a tuple of CUTLASS sparse GEMM kernel arguments.
    typename Gemm::Arguments arguments{
        problem_size,
        tensor_a_device_ref,
        tensor_b_device_ref,
        tensor_e_reordered_device_ref,
        callback_arguments};

    cutlass::Status status;

    // Create CUTLASS sparse GEMM kernel object.
    Gemm gemm_op;

    // Verify that sparse GEMM operation with given arguments can be
    // performed by CUTLASS.
    status = gemm_op.can_implement(arguments);
    CUTLASS_STATUS_CHECK(status);

    // Allocate workspace for CUTLASS sparse GEMM kernel.
    const auto workspace_size = Gemm::get_workspace_size(arguments);
    auto workspace = tensor_a.new_empty({(int64_t)workspace_size},
                                        at::TensorOptions().dtype(at::kByte));

    // Initialize CUTLASS sparse GEMM object.
    status = gemm_op.initialize(arguments, workspace.data_ptr(),
                                at::cuda::getCurrentCUDAStream());
    CUTLASS_STATUS_CHECK(status);

    // Perform sparse GEMM operation.
    status = gemm_op.run(at::cuda::getCurrentCUDAStream());
    CUTLASS_STATUS_CHECK(status);

    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return tensor_d;
}

// Dispatch according to the combination of input tensor layouts.
template <
    typename ElementInputA,
    typename ElementInputB,
    typename ElementOutput,
    typename ElementAccumulator,
    typename ThreadblockShape,
    typename WarpShape,
    typename InstructionShape,
    bool EnableRowMajorRowMajorLayouts,
    bool EnableRowMajorColumnMajorLayouts,
    bool EnableColumnMajorRowMajorLayouts,
    bool EnableColumnMajorColumnMajorLayouts,
    bool use_bias,
    Activation activation>
Tensor two_four_sgemm_dispatch_layouts(
    const Tensor& tensor_a, const Tensor& tensor_b, const Tensor& tensor_c,
    const Tensor& meta) {
    // Determine layouts (row-major or column-major) of input tensors.
    const auto strides_a = tensor_a.strides();
    auto tensor_a_row_major = strides_a[1] == 1;
    auto tensor_a_stride = tensor_a_row_major ? strides_a[0] : strides_a[1];
    const auto strides_b = tensor_b.strides();
    auto tensor_b_row_major = strides_b[1] == 1;
    auto tensor_b_stride = tensor_b_row_major ? strides_b[0] : strides_b[1];
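    // For example, a contiguous row-major [m, k] tensor has strides
    // [k, 1], so its row stride k is used; a column-major [m, k]
    // tensor has strides [1, m], and its column stride m is used
    // instead.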

    // Perform dispatching.
    if constexpr (EnableRowMajorRowMajorLayouts) {
        if (tensor_a_row_major && tensor_b_row_major) {
            return two_four_sgemm<
                ElementInputA,
                ElementInputB,
                ElementOutput,
                ElementAccumulator,
                ThreadblockShape,
                WarpShape,
                InstructionShape,
                cutlass::layout::RowMajor,
                cutlass::layout::RowMajor,
                use_bias,
                activation>(
                tensor_a,
                tensor_a_stride,
                tensor_b,
                tensor_b_stride,
                tensor_c,
                meta);
        }
    }
    if constexpr (EnableRowMajorColumnMajorLayouts) {
        if (tensor_a_row_major && !tensor_b_row_major) {
            return two_four_sgemm<
                ElementInputA,
                ElementInputB,
                ElementOutput,
                ElementAccumulator,
                ThreadblockShape,
                WarpShape,
                InstructionShape,
                cutlass::layout::RowMajor,
                cutlass::layout::ColumnMajor,
                use_bias,
                activation>(
                tensor_a,
                tensor_a_stride,
                tensor_b,
                tensor_b_stride,
                tensor_c,
                meta);
        }
    }
    if constexpr (EnableColumnMajorRowMajorLayouts) {
        if (!tensor_a_row_major && tensor_b_row_major) {
            return two_four_sgemm<
                ElementInputA,
                ElementInputB,
                ElementOutput,
                ElementAccumulator,
                ThreadblockShape,
                WarpShape,
                InstructionShape,
                cutlass::layout::ColumnMajor,
                cutlass::layout::RowMajor,
                use_bias,
                activation>(
                tensor_a,
                tensor_a_stride,
                tensor_b,
                tensor_b_stride,
                tensor_c,
                meta);
        }
    }
    if constexpr (EnableColumnMajorColumnMajorLayouts) {
        if (!tensor_a_row_major && !tensor_b_row_major) {
            return two_four_sgemm<
                ElementInputA,
                ElementInputB,
                ElementOutput,
                ElementAccumulator,
                ThreadblockShape,
                WarpShape,
                InstructionShape,
                cutlass::layout::ColumnMajor,
                cutlass::layout::ColumnMajor,
                use_bias,
                activation>(
                tensor_a,
                tensor_a_stride,
                tensor_b,
                tensor_b_stride,
                tensor_c,
                meta);
        }
    }

    AT_ERROR("two_four_sgemm_dispatch_layouts: Combination of ",
             tensor_a_row_major ? "row-major" : "column-major", " and ",
             tensor_b_row_major ? "row-major" : "column-major",
             " layouts for input tensors is not supported");
    return Tensor{};
}

// Dispatch according to whether a bias tensor is provided.
template <
    typename ElementInputA,
    typename ElementInputB,
    typename ElementOutput,
    typename ElementAccumulator,
    typename ThreadblockShape,
    typename WarpShape,
    typename InstructionShape,
    bool EnableRowMajorRowMajorLayouts,
    bool EnableRowMajorColumnMajorLayouts,
    bool EnableColumnMajorRowMajorLayouts,
    bool EnableColumnMajorColumnMajorLayouts,
    Activation activation>
Tensor two_four_sgemm_dispatch_layouts_bias(
    const Tensor& tensor_a, const Tensor& tensor_b, const Tensor& tensor_c,
    const Tensor& meta) {
    if (tensor_c.numel() > 0) {
        return two_four_sgemm_dispatch_layouts<
            ElementInputA,
            ElementInputB,
            ElementOutput,
            ElementAccumulator,
            ThreadblockShape,
            WarpShape,
            InstructionShape,
            EnableRowMajorRowMajorLayouts,
            EnableRowMajorColumnMajorLayouts,
            EnableColumnMajorRowMajorLayouts,
            EnableColumnMajorColumnMajorLayouts,
            true,
            activation>(
            tensor_a,
            tensor_b,
            tensor_c,
            meta);
    } else {
        return two_four_sgemm_dispatch_layouts<
            ElementInputA,
            ElementInputB,
            ElementOutput,
            ElementAccumulator,
            ThreadblockShape,
            WarpShape,
            InstructionShape,
            EnableRowMajorRowMajorLayouts,
            EnableRowMajorColumnMajorLayouts,
            EnableColumnMajorRowMajorLayouts,
            EnableColumnMajorColumnMajorLayouts,
            false,
            activation>(
            tensor_a,
            tensor_b,
            tensor_c,
            meta);
    }
}

// Dispatch according to the activation functions enabled.
template <
    typename ElementInputA,
    typename ElementInputB,
    typename ElementOutput,
    typename ElementAccumulator,
    typename ThreadblockShape,
    typename WarpShape,
    typename InstructionShape,
    bool EnableRowMajorRowMajorLayouts,
    bool EnableRowMajorColumnMajorLayouts,
    bool EnableColumnMajorRowMajorLayouts,
    bool EnableColumnMajorColumnMajorLayouts,
    bool EnableActivationNone,
    bool EnableActivationReLU,
    bool EnableActivationSiLU>
Tensor two_four_sgemm_dispatch_layouts_bias_activation(
    const Tensor& tensor_a, const Tensor& tensor_b, const Tensor& tensor_c,
    const Tensor& meta, const c10::string_view& activation) {
    // Perform dispatching.
    if constexpr (EnableActivationNone) {
        if (activation == "none") {
            return two_four_sgemm_dispatch_layouts_bias<
                ElementInputA,
                ElementInputB,
                ElementOutput,
                ElementAccumulator,
                ThreadblockShape,
                WarpShape,
                InstructionShape,
                EnableRowMajorRowMajorLayouts,
                EnableRowMajorColumnMajorLayouts,
                EnableColumnMajorRowMajorLayouts,
                EnableColumnMajorColumnMajorLayouts,
                Activation::NONE>(
                tensor_a,
                tensor_b,
                tensor_c,
                meta);
        }
    }
    if constexpr (EnableActivationReLU) {
        if (activation == "relu") {
            return two_four_sgemm_dispatch_layouts_bias<
                ElementInputA,
                ElementInputB,
                ElementOutput,
                ElementAccumulator,
                ThreadblockShape,
                WarpShape,
                InstructionShape,
                EnableRowMajorRowMajorLayouts,
                EnableRowMajorColumnMajorLayouts,
                EnableColumnMajorRowMajorLayouts,
                EnableColumnMajorColumnMajorLayouts,
                Activation::RELU>(
                tensor_a,
                tensor_b,
                tensor_c,
                meta);
        }
    }
    if constexpr (EnableActivationSiLU) {
        if (activation == "silu") {
            return two_four_sgemm_dispatch_layouts_bias<
                ElementInputA,
                ElementInputB,
                ElementOutput,
                ElementAccumulator,
                ThreadblockShape,
                WarpShape,
                InstructionShape,
                EnableRowMajorRowMajorLayouts,
                EnableRowMajorColumnMajorLayouts,
                EnableColumnMajorRowMajorLayouts,
                EnableColumnMajorColumnMajorLayouts,
                Activation::SILU>(
                tensor_a,
                tensor_b,
                tensor_c,
                meta);
        }
    }

    AT_ERROR("two_four_sgemm_dispatch_layouts_bias_activation: Activation \"",
             activation, "\" is not supported for given input tensors");
    return Tensor{};
}
#endif

// Perform a linear transformation, output = input * weight.T + bias,
// using the corresponding CUTLASS sparse GEMM kernel.  The "input"
// tensor is a dense tensor, while the "weight" tensor is a matrix
// with a 2:4 sparsity pattern.  The "bias" tensor is optional; if
// provided, it should be a vector, with the number of elements equal
// to the number of rows of the "weight" matrix.  It is assumed that
// "input", after squashing eventual batch dimensions with the
// next-to-last dimension of this tensor, and "weight" are supplied
// either in row-major or column-major layouts (different layouts
// between these two tensors are OK, but not all combinations of
// layouts are supported for some datatypes of these matrices).  The
// "meta" argument contains the metadata matrix.  The function
// returns the output tensor.
//
// The CUTLASS sparse GEMM kernel comes with numerous limitations with
// regard to the sizes, alignments, layouts, and datatypes of the
// input tensors; this is the reason for the large number of checks
// throughout the code.
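//
// Illustrative shapes (chosen for this example only): a dense weight
// of size [64, 128], pruned 2:4 and compressed to a [64, 64] "weight"
// argument, an "input" of size [8, 128], and an optional "bias" of
// size [64] produce an output of size [8, 64].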
Tensor _sparse_semi_structured_linear(
      const Tensor& input, const Tensor& weight,
      const Tensor& meta, const std::optional<Tensor>& bias_opt,
      const std::optional<c10::string_view> activation_opt,
      const std::optional<c10::ScalarType> out_dtype_opt) {
    TORCH_WARN_ONCE("_sparse_semi_structured_linear is deprecated and will be "
                    "removed in a future PyTorch release.  Please use "
                    "_sparse_semi_structured_mm/_sparse_semi_structured_addmm "
                    "instead.");
#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
    AT_ERROR("_sparse_semi_structured_linear: CUTLASS not supported");
    return Tensor{};
#else
    // No need to check that all tensors are on CUDA device, as this
    // is provided by dispatch.

    // Introduce alias names for arguments, according to the CUTLASS
    // naming conventions.  Also, squash the batch dimensions of the
    // input tensor with its next-to-last dimension.
    const auto input_sizes = input.sizes().vec();
    const auto tensor_a = weight;
    const auto tensor_b =
        input.reshape({-1, input_sizes.back()}).transpose(-1, -2);
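    // For example, an input of size [2, 3, 128] is reshaped to
    // [6, 128] and then transposed, so tensor_b has size [128, 6].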
    const auto tensor_c = bias_opt.has_value() ? *bias_opt : Tensor{};

    const auto activation =
        activation_opt.has_value() ? *activation_opt : "none";

    TORCH_CHECK(!out_dtype_opt.has_value() ||
                (tensor_a.dtype() == at::ScalarType::Char &&
                 out_dtype_opt.value() == at::ScalarType::Int),
                "_sparse_semi_structured_linear: Setting out_dtype is only "
                "supported for int8 input and int32 output");

    // For now, only CC 8.x devices are supported.
    const auto dprops = at::cuda::getCurrentDeviceProperties();
    const auto is_sm8x = dprops->major == 8;
    TORCH_CHECK(is_sm8x,
                "_sparse_semi_structured_linear: Supported only on GPUs with "
                "compute capability 8.x");

    // Validate datatypes of input tensors.
    TORCH_CHECK(tensor_a.dtype() == at::kChar ||
                tensor_a.dtype() == at::kHalf ||
                tensor_a.dtype() == at::kBFloat16 ||
                tensor_a.dtype() == at::kFloat,
                "_sparse_semi_structured_linear: The weight datatype ",
                tensor_a.dtype(), " is not supported");
    TORCH_CHECK(tensor_b.dtype() == tensor_a.dtype(),
                "_sparse_semi_structured_linear: Expected input datatype ",
                tensor_a.dtype(), ", but got ", tensor_b.dtype());

    // Validate layouts of input tensors.
    TORCH_CHECK(tensor_a.layout() == Layout::Strided,
                "_sparse_semi_structured_linear: Expected weight argument "
                "to be strided, but got layout ", tensor_a.layout());
    TORCH_CHECK(tensor_a.dim() == 2,
                "_sparse_semi_structured_linear: Expected weight argument "
                "to be 2D tensor, got ", tensor_a.dim(), " dims");
    const auto strides_a = tensor_a.strides();
    TORCH_CHECK((strides_a[0] == 1 || strides_a[1] == 1) &&
                strides_a[0] != strides_a[1],
                "_sparse_semi_structured_linear: Invalid strides for weight "
                "argument: row stride = ", strides_a[0], ", column stride = ",
                strides_a[1]);
    TORCH_CHECK(tensor_b.layout() == Layout::Strided,
                "_sparse_semi_structured_linear: Expected input argument "
                "to be strided, but got layout ", tensor_b.layout());
    TORCH_CHECK(tensor_b.dim() == 2,
                "_sparse_semi_structured_linear: Expected input argument "
                "to be 2D tensor, got ", tensor_b.dim(), " dims");
    const auto strides_b = tensor_b.strides();
    TORCH_CHECK((strides_b[0] == 1 || strides_b[1] == 1) &&
                strides_b[0] != strides_b[1],
                "_sparse_semi_structured_linear: Invalid strides for input "
                "argument: row stride = ", strides_b[0], ", column stride = ",
                strides_b[1]);
    if (tensor_c.numel() != 0) {
        TORCH_CHECK(tensor_c.layout() == Layout::Strided,
                    "_sparse_semi_structured_linear: Expected bias argument "
                    "to be strided, but got layout ", tensor_c.layout());
        TORCH_CHECK(tensor_c.dim() == 1,
                    "_sparse_semi_structured_linear: Expected bias argument "
                    "to be 1D tensor, got ", tensor_c.dim(), " dims");
    }

    // Validate sizes of input tensors.
    TORCH_CHECK(tensor_a.size(1) == tensor_b.size(0) / 2,
                "_sparse_semi_structured_linear: Expected weight argument "
                "to have ", tensor_b.size(0) / 2, " columns, but got ",
                tensor_a.size(1));
    if (tensor_c.numel() != 0) {
        TORCH_CHECK(tensor_c.size(0) == tensor_a.size(0),
                    "_sparse_semi_structured_linear: Expected bias argument "
                    "to have ", tensor_a.size(0), " elements, but got ",
                    tensor_c.size(0));
    }

    // Call the wrapper function for CUTLASS sparse GEMM, dispatching
    // on the input datatype and then on the input tensor layouts.
    // According to the datatypes and layouts of the input tensors, the
    // corresponding template arguments are supplied for instantiating
    // the wrapper function.  The tile size template arguments are
    // selected according to CUTLASS profiler results over a number of
    // runs.
    Tensor output;
    AT_DISPATCH_SWITCH(
        tensor_a.scalar_type(),
        "_sparse_semi_structured_linear",
        AT_DISPATCH_CASE(
            at::ScalarType::Char,
            [&]() {
                using ElementInputA = int8_t;
                using ElementInputB = int8_t;
                using ElementAccumulator = int32_t;
                using ThreadblockShape =
                    cutlass::gemm::GemmShape<128, 128, 128>;
                using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>;
                using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
                const auto EnableRowMajorRowMajorLayouts = false;
                const auto EnableRowMajorColumnMajorLayouts = true;
                const auto EnableColumnMajorRowMajorLayouts = false;
                const auto EnableColumnMajorColumnMajorLayouts = false;
                const auto EnableActivationNone = true;
                const auto EnableActivationReLU = true;
                const auto EnableActivationSiLU = false;
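                // Presumably, only the row-major x column-major operand
                // combination is enabled here because CUTLASS integer
                // tensor-op kernels expect TN operand layouts, and SiLU
                // is left disabled since it is not meaningful on integer
                // outputs (both are assumptions, not statements from the
                // CUTLASS documentation).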
                if (out_dtype_opt.has_value()) {
                  using ElementOutput = int32_t;
                  output = two_four_sgemm_dispatch_layouts_bias_activation<
                      ElementInputA,
                      ElementInputB,
                      ElementOutput,
                      ElementAccumulator,
                      ThreadblockShape,
                      WarpShape,
                      InstructionShape,
                      EnableRowMajorRowMajorLayouts,
                      EnableRowMajorColumnMajorLayouts,
                      EnableColumnMajorRowMajorLayouts,
                      EnableColumnMajorColumnMajorLayouts,
                      EnableActivationNone,
                      EnableActivationReLU,
                      EnableActivationSiLU>(
                      tensor_a,
                      tensor_b,
                      tensor_c,
                      meta,
                      activation);
                } else {
                  using ElementOutput = int8_t;
                  output = two_four_sgemm_dispatch_layouts_bias_activation<
                      ElementInputA,
                      ElementInputB,
                      ElementOutput,
                      ElementAccumulator,
                      ThreadblockShape,
                      WarpShape,
                      InstructionShape,
                      EnableRowMajorRowMajorLayouts,
                      EnableRowMajorColumnMajorLayouts,
                      EnableColumnMajorRowMajorLayouts,
                      EnableColumnMajorColumnMajorLayouts,
                      EnableActivationNone,
                      EnableActivationReLU,
                      EnableActivationSiLU>(
                      tensor_a,
                      tensor_b,
                      tensor_c,
                      meta,
                      activation);
                }
                return;
            })
        AT_DISPATCH_CASE(
            at::ScalarType::Half,
            [&]() {
                using ElementInputA = cutlass::half_t;
                using ElementInputB = cutlass::half_t;
                using ElementOutput = cutlass::half_t;
                using ElementAccumulator = float;
                using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
                using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
                using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
                const auto EnableRowMajorRowMajorLayouts = true;
                const auto EnableRowMajorColumnMajorLayouts = true;
                const auto EnableColumnMajorRowMajorLayouts = true;
                const auto EnableColumnMajorColumnMajorLayouts = true;
                const auto EnableActivationNone = true;
                const auto EnableActivationReLU = true;
                const auto EnableActivationSiLU = true;
                output = two_four_sgemm_dispatch_layouts_bias_activation<
                    ElementInputA,
                    ElementInputB,
                    ElementOutput,
                    ElementAccumulator,
                    ThreadblockShape,
                    WarpShape,
                    InstructionShape,
                    EnableRowMajorRowMajorLayouts,
                    EnableRowMajorColumnMajorLayouts,
                    EnableColumnMajorRowMajorLayouts,
                    EnableColumnMajorColumnMajorLayouts,
                    EnableActivationNone,
                    EnableActivationReLU,
                    EnableActivationSiLU>(
                    tensor_a,
                    tensor_b,
                    tensor_c,
                    meta,
                    activation);
                return;
            })
        AT_DISPATCH_CASE(
            at::ScalarType::BFloat16,
            [&]() {
                using ElementInputA = cutlass::bfloat16_t;
                using ElementInputB = cutlass::bfloat16_t;
                using ElementOutput = cutlass::bfloat16_t;
                using ElementAccumulator = float;
                using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
                using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
                using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
                const auto EnableRowMajorRowMajorLayouts = true;
                const auto EnableRowMajorColumnMajorLayouts = true;
                const auto EnableColumnMajorRowMajorLayouts = true;
                const auto EnableColumnMajorColumnMajorLayouts = true;
                const auto EnableActivationNone = true;
                const auto EnableActivationReLU = true;
                const auto EnableActivationSiLU = true;
                output = two_four_sgemm_dispatch_layouts_bias_activation<
                    ElementInputA,
                    ElementInputB,
                    ElementOutput,
                    ElementAccumulator,
                    ThreadblockShape,
                    WarpShape,
                    InstructionShape,
                    EnableRowMajorRowMajorLayouts,
                    EnableRowMajorColumnMajorLayouts,
                    EnableColumnMajorRowMajorLayouts,
                    EnableColumnMajorColumnMajorLayouts,
                    EnableActivationNone,
                    EnableActivationReLU,
                    EnableActivationSiLU>(
                    tensor_a,
                    tensor_b,
                    tensor_c,
                    meta,
                    activation);
                return;
            })
        AT_DISPATCH_CASE(
            at::ScalarType::Float,
            [&]() {
                using ElementInputA = float;
                using ElementInputB = float;
                using ElementOutput = float;
                using ElementAccumulator = float;
                using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>;
                using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>;
                using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
                const auto EnableRowMajorRowMajorLayouts = true;
                const auto EnableRowMajorColumnMajorLayouts = true;
                const auto EnableColumnMajorRowMajorLayouts = true;
                const auto EnableColumnMajorColumnMajorLayouts = true;
                const auto EnableActivationNone = true;
                const auto EnableActivationReLU = true;
                const auto EnableActivationSiLU = true;
                output = two_four_sgemm_dispatch_layouts_bias_activation<
                    ElementInputA,
                    ElementInputB,
                    ElementOutput,
                    ElementAccumulator,
                    ThreadblockShape,
                    WarpShape,
                    InstructionShape,
                    EnableRowMajorRowMajorLayouts,
                    EnableRowMajorColumnMajorLayouts,
                    EnableColumnMajorRowMajorLayouts,
                    EnableColumnMajorColumnMajorLayouts,
                    EnableActivationNone,
                    EnableActivationReLU,
                    EnableActivationSiLU>(
                    tensor_a,
                    tensor_b,
                    tensor_c,
                    meta,
                    activation);
                return;
            }));

    // Re-introduce batch dimensions into the output, and return.
    auto output_sizes = input_sizes;
    output_sizes.back() = weight.size(0);
    return output.transpose(-1, -2).reshape(output_sizes);
#endif
}

} // namespace at::native