#include <type_traits>

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <ATen/native/cuda/PersistentSoftmax.cuh>
#include <ATen/native/cuda/block_reduce.cuh>

#include <c10/cuda/CUDAMathCompat.h>
#include <c10/cuda/CUDAStream.h>

#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
#include <ATen/native/nested/NestedTensorUtils.h>

#ifndef USE_ROCM
#ifndef _WIN32
#include <cutlass/gemm/device/default_gemm_configuration.h>
#include <cutlass/gemm/device/gemm_grouped.h>
#include <cutlass/gemm/kernel/default_gemm_grouped.h>
#endif
#endif

#include <ATen/NestedTensorImpl.h>

#define BLOCK_DIM 256
#define GRID_DIM_Y 16

namespace at {
namespace native {

#ifndef USE_ROCM
#ifndef _WIN32
namespace {

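// Runs one CUTLASS grouped GEMM over `problem_count` independent matmuls.
// Per-problem leading dimensions, operand/output pointers, and problem sizes
// are packed into a single pinned host buffer, copied to the device in one
// transfer, and then handed to the grouped-GEMM kernel as device pointers.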
template <
    typename scalar_t,
    unsigned int kPad,
    typename LayoutA,
    typename LayoutB,
    typename OpClass,
    typename Arch,
    typename ThreadBlockShape,
    typename WarpShape,
    typename InstructionShape>
void gemm_grouped_cuda_internal(
    const std::vector<int64_t>& lda,
    const std::vector<int64_t>& ldb,
    const std::vector<int64_t>& ldd,
    const std::vector<scalar_t*>& aptr,
    const std::vector<scalar_t*>& bptr,
    const std::vector<scalar_t*>& dptr,
    const std::vector<cutlass::gemm::GemmCoord>& gemm_sizes,
    const int problem_count,
    at::Device& device) {
  using Element = scalar_t;
  using ElementAcc = float;

  using GemmConfiguration =
      typename cutlass::gemm::device::DefaultGemmConfiguration<
          OpClass,
          Arch,
          Element,
          Element,
          Element,
          ElementAcc>;

  using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
      Element,
      LayoutA,
      cutlass::ComplexTransform::kNone,
      kPad,
      Element,
      LayoutB,
      cutlass::ComplexTransform::kNone,
      kPad,
      Element,
      cutlass::layout::RowMajor,
      ElementAcc,
      OpClass,
      Arch,
      ThreadBlockShape,
      WarpShape,
      InstructionShape,
      typename GemmConfiguration::EpilogueOutputOp,
      cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
      GemmConfiguration::kStages>::GemmKernel;

  using GemmGrouped = typename cutlass::gemm::device::GemmGrouped<GemmKernel>;
  using EpilogueOutputOp = typename GemmGrouped::GemmKernel::Epilogue::OutputOp;
  typename EpilogueOutputOp::Params epilogue_op(/*alpha*/ 1, /*beta*/ 0);

  const int64_t gemm_coord_size =
      problem_count * ((int64_t)sizeof(cutlass::gemm::GemmCoord));
  // Number of gmm args not including *problem_sizes
  at::Tensor gmm_args = at::empty(
      {problem_count * 6 + gemm_coord_size},
      at::TensorOptions().dtype(at::kLong).pinned_memory(true));

  // Obtain pointers for each argument (on host)
  int64_t* lda_data = gmm_args.data_ptr<int64_t>(); // Base pointer
  int64_t* ldb_data = lda_data + problem_count;
  int64_t* ldd_data = lda_data + 2 * problem_count;
  int64_t* ptr_a_data = lda_data + 3 * problem_count;
  int64_t* ptr_b_data = lda_data + 4 * problem_count;
  int64_t* ptr_d_data = lda_data + 5 * problem_count;
  cutlass::gemm::GemmCoord* problem_sizes_data =
      reinterpret_cast<cutlass::gemm::GemmCoord*>(lda_data + 6 * problem_count);

  // Set arguments into gmm_args from input args
  for (int i = 0; i < problem_count; ++i) {
    problem_sizes_data[i] = gemm_sizes[i];
    lda_data[i] = lda[i];
    ldb_data[i] = ldb[i];
    ldd_data[i] = ldd[i];
    ptr_a_data[i] = reinterpret_cast<int64_t>(aptr[i]);
    ptr_b_data[i] = reinterpret_cast<int64_t>(bptr[i]);
    ptr_d_data[i] = reinterpret_cast<int64_t>(dptr[i]);
  }
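  // Query CUTLASS for how many threadblocks to launch for this problem set.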
  const int threadblock_count =
      GemmGrouped::sufficient(problem_sizes_data, problem_count);

  // Transfer arguments to GPU
  gmm_args = gmm_args.to(device, true);

  // Obtain pointers for each argument (on GPU)
  lda_data = gmm_args.data_ptr<int64_t>(); // Base pointer
  ldb_data = lda_data + problem_count;
  ldd_data = lda_data + 2 * problem_count;
  ptr_a_data = lda_data + 3 * problem_count;
  ptr_b_data = lda_data + 4 * problem_count;
  ptr_d_data = lda_data + 5 * problem_count;
  problem_sizes_data =
      reinterpret_cast<cutlass::gemm::GemmCoord*>(lda_data + 6 * problem_count);

  // Create GemmGrouped::Arguments using the arguments prepared above
  typename GemmGrouped::Arguments args(
      problem_sizes_data,
      problem_count,
      threadblock_count,
      epilogue_op,
      reinterpret_cast<Element**>(ptr_a_data),
      reinterpret_cast<Element**>(ptr_b_data),
      reinterpret_cast<Element**>(ptr_d_data),
      reinterpret_cast<Element**>(ptr_d_data),
      lda_data,
      ldb_data,
      ldd_data,
      ldd_data);

  GemmGrouped gemm;
  cutlass::Status status =
      gemm.initialize(args, nullptr, at::cuda::getCurrentCUDAStream());
  TORCH_CHECK(
      status != cutlass::Status::kErrorWorkspaceNull,
      "Failed to initialize CUTLASS Grouped GEMM kernel due to workspace.");
  TORCH_CHECK(
      status != cutlass::Status::kErrorInternal,
      "Failed to initialize CUTLASS Grouped GEMM kernel due to internal error.");
  TORCH_CHECK(
      status == cutlass::Status::kSuccess,
      "Failed to initialize CUTLASS Grouped GEMM kernel.");

  // Run CUTLASS group GEMM
  status = gemm.run(at::cuda::getCurrentCUDAStream());
  TORCH_CHECK(
      status == cutlass::Status::kSuccess,
      "Failed to run CUTLASS Grouped GEMM kernel.");

  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

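// Dispatches a batch of matmuls to a CUTLASS grouped GEMM for the dtypes we
// support. The generic template returns false so callers fall back to the
// non-CUTLASS path; the float and Half specializations below launch a kernel
// and return true.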
template <typename scalar_t>
bool group_gemm_dispatch(
    at::Device device,
    const std::vector<scalar_t*>& aptr,
    const std::vector<scalar_t*>& bptr,
    const std::vector<scalar_t*>& dptr,
    const std::vector<int64_t>& lda,
    const std::vector<int64_t>& ldb,
    const std::vector<int64_t>& ldd,
    std::vector<cutlass::gemm::GemmCoord> gemm_sizes,
    int64_t ntensors) {
  return false;
}

template <>
bool group_gemm_dispatch(
    at::Device device,
    const std::vector<float*>& aptr,
    const std::vector<float*>& bptr,
    const std::vector<float*>& dptr,
    const std::vector<int64_t>& lda,
    const std::vector<int64_t>& ldb,
    const std::vector<int64_t>& ldd,
    std::vector<cutlass::gemm::GemmCoord> gemm_sizes,
    int64_t ntensors) {

  gemm_grouped_cuda_internal<
      float,
      1,
      cutlass::layout::RowMajor,
      cutlass::layout::RowMajor,
      cutlass::arch::OpClassSimt,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 128, 8>,
      cutlass::gemm::GemmShape<64, 32, 8>,
      cutlass::gemm::GemmShape<1, 1, 1>>(
      lda, ldb, ldd, aptr, bptr, dptr, gemm_sizes, ntensors, device);
  return true;
}

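// Half specialization: uses the Sm80 tensor-core configuration when every
// problem's n, k, and leading dimensions are multiples of 8; otherwise falls
// back to a SIMT configuration with alignment 1.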
template <>
bool group_gemm_dispatch(
    at::Device device,
    const std::vector<c10::Half*>& aptr_,
    const std::vector<c10::Half*>& bptr_,
    const std::vector<c10::Half*>& dptr_,
    const std::vector<int64_t>& lda,
    const std::vector<int64_t>& ldb,
    const std::vector<int64_t>& ldd,
    std::vector<cutlass::gemm::GemmCoord> gemm_sizes,
    int64_t ntensors) {

  // Check alignment
  bool all_pad_8 = true;
  for (int i = 0; i < ntensors; i++) {
    all_pad_8 = all_pad_8 && (gemm_sizes[i].n() % 8 == 0);
    all_pad_8 = all_pad_8 && (gemm_sizes[i].k() % 8 == 0);

    // Not sure if this is a requirement, on the safe side
    all_pad_8 = all_pad_8 && (lda[i] % 8 == 0);
    all_pad_8 = all_pad_8 && (ldb[i] % 8 == 0);
    all_pad_8 = all_pad_8 && (ldd[i] % 8 == 0);
  }

  std::vector<cutlass::half_t*> aptr;
  std::vector<cutlass::half_t*> bptr;
  std::vector<cutlass::half_t*> dptr;
  for (int64_t i = 0; i < ntensors; i++) {
    aptr.push_back(reinterpret_cast<cutlass::half_t*>(aptr_[i]));
    bptr.push_back(reinterpret_cast<cutlass::half_t*>(bptr_[i]));
    dptr.push_back(reinterpret_cast<cutlass::half_t*>(dptr_[i]));
  }
  if (all_pad_8) {
    gemm_grouped_cuda_internal<
        cutlass::half_t,
        8,
        cutlass::layout::RowMajor,
        cutlass::layout::RowMajor,
        cutlass::arch::OpClassTensorOp,
        cutlass::arch::Sm80,
        cutlass::gemm::GemmShape<128, 128, 32>,
        cutlass::gemm::GemmShape<64, 64, 32>,
        cutlass::gemm::GemmShape<16, 8, 16>>(
        lda, ldb, ldd, aptr, bptr, dptr, gemm_sizes, ntensors, device);
    return true;
  } else {
    gemm_grouped_cuda_internal<
        cutlass::half_t,
        1,
        cutlass::layout::RowMajor,
        cutlass::layout::RowMajor,
        cutlass::arch::OpClassSimt,
        cutlass::arch::Sm80,
        cutlass::gemm::GemmShape<128, 128, 8>,
        cutlass::gemm::GemmShape<64, 32, 8>,
        cutlass::gemm::GemmShape<1, 1, 1>>(
        lda, ldb, ldd, aptr, bptr, dptr, gemm_sizes, ntensors, device);
    return true;
  }
  // Did not perform GEMM
  return false;
}

} // namespace

#endif
#endif

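// Batched matmul for nested tensors (and nested/dense mixes). Validates that
// the per-tensor matrix shapes are compatible, allocates a contiguous nested
// output, and, when possible, runs all matmuls in a single CUTLASS grouped
// GEMM; otherwise it falls back to one at::mm_out call per tensor.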
Tensor bmm_nested_cuda(const Tensor& self, const Tensor& mat2) {

  // dispatcher should have guaranteed that at least one is nested
  auto self_ptr = self.is_nested() ? get_nested_tensor_impl(self) : self.unsafeGetTensorImpl();
  auto mat2_ptr = mat2.is_nested() ? get_nested_tensor_impl(mat2) : mat2.unsafeGetTensorImpl();
  TORCH_CHECK(self_ptr->dim() == 3, "batch1 must be a 3D tensor");
  TORCH_CHECK(mat2_ptr->dim() == 3, "batch2 must be a 3D tensor");
  int64_t ntensors = self_ptr->size(0), ntensors2 = mat2_ptr->size(0);
  TORCH_CHECK(
      ntensors == ntensors2,
      "Expected size for the 1st dimension of batch2 tensor to be: ",
      ntensors,
      " but got: ",
      ntensors2,
      ".");

  // create a contiguous output
  const Tensor& self_sizemat = self.is_nested() ?
      get_nested_tensor_impl(self)->get_nested_sizes() : get_nested_tensor_impl(mat2)->get_nested_sizes();

  Tensor out_sizemat = self_sizemat.new_empty(self_sizemat.sizes());
  int64_t* out_sizemat_ptr = out_sizemat.data_ptr<int64_t>();

  int64_t out_numel = 0;
  for (int64_t i = 0; i < ntensors; i++) {
    const IntArrayRef &self_shape = get_size_for_index(self, i), &mat2_shape = get_size_for_index(mat2, i);
    const int64_t &self_size0 = self_shape[0], &self_size1 = self_shape[1],
                  &mat2_size0 = mat2_shape[0], &mat2_size1 = mat2_shape[1];
    TORCH_CHECK(
        self_size1 == mat2_size0,
        i,
        "-th nested matrices in batch cannot be multiplied (",
        self_size0,
        "x",
        self_size1,
        " and ",
        mat2_size0,
        "x",
        mat2_size1,
        ")");
    out_sizemat_ptr[0] = self_size0;
    out_sizemat_ptr[1] = mat2_size1;
    out_sizemat_ptr += 2;
    out_numel += self_size0 * mat2_size1;
  }

  const Tensor &self_buffer = self.is_nested() ? get_nested_tensor_impl(self)->get_unsafe_storage_as_tensor() : self;
  const Tensor &mat2_buffer = mat2.is_nested() ? get_nested_tensor_impl(mat2)->get_unsafe_storage_as_tensor() : mat2;

  Tensor out_buffer = self_buffer.new_empty(out_numel);
  Tensor output = wrap_buffer(out_buffer, out_sizemat);
  auto out_ptr = get_nested_tensor_impl(output);

  const int64_t *out_offsets_ptr = out_ptr->get_storage_offsets().const_data_ptr<int64_t>();

#ifndef USE_ROCM
#ifndef _WIN32
  bool success = false;
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      self.scalar_type(), "group_gemm_dispatch", [&] {
        std::vector<scalar_t*> aptr(ntensors);
        std::vector<scalar_t*> bptr(ntensors);
        std::vector<scalar_t*> dptr(ntensors);
        std::vector<int64_t> lda(ntensors);
        std::vector<int64_t> ldb(ntensors);
        std::vector<int64_t> ldd(ntensors);
        std::vector<cutlass::gemm::GemmCoord> gemm_sizes;
        bool all_row_major = true;
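        // Gather per-tensor GEMM arguments: problem sizes (m, n, k), operand
        // and output pointers into the flat buffers, and leading dimensions,
        // while checking that every matrix is row-major (innermost stride 1).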
        for (int64_t i = 0; i < ntensors; i++) {
          const IntArrayRef& self_shape = get_size_for_index(self, i);
          const IntArrayRef& mat2_shape = get_size_for_index(mat2, i);
          const int64_t &self_size0 = self_shape[0];
          const int64_t &self_size1 = self_shape[1];
          const int64_t &mat2_size0 = mat2_shape[0];
          const int64_t &mat2_size1 = mat2_shape[1];
          gemm_sizes.push_back(
              cutlass::gemm::GemmCoord(self_size0, mat2_size1, self_size1));
          aptr[i] = self_buffer.data_ptr<scalar_t>() + get_offset_for_index(self, i);
          bptr[i] = mat2_buffer.data_ptr<scalar_t>() + get_offset_for_index(mat2, i);
          dptr[i] = out_buffer.data_ptr<scalar_t>() + out_offsets_ptr[i];
          auto self_stride = get_stride_for_index(self, i);
          auto mat2_stride = get_stride_for_index(mat2, i);
          all_row_major = all_row_major && (self_stride[1] == 1);
          all_row_major = all_row_major && (mat2_stride[1] == 1);
          lda[i] = self_stride[0];
          ldb[i] = mat2_stride[0];
          ldd[i] = mat2_size1;
        }
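        // The CUTLASS grouped-GEMM kernels above are instantiated for Sm80,
        // so only take this path on SM 8.x devices with contiguous, row-major
        // inputs; anything else falls through to the per-tensor mm fallback.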
        auto dprops = at::cuda::getCurrentDeviceProperties();
        bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
        if (all_row_major &&
            self.is_contiguous() &&
            mat2.is_contiguous() &&
            is_sm8x) {
          success = group_gemm_dispatch<scalar_t>(
              output.device(),
              aptr,
              bptr,
              dptr,
              lda,
              ldb,
              ldd,
              gemm_sizes,
              ntensors);
        }
      });
  if (success) {
    return output;
  }
#endif
#endif

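  // Fallback: run one at::mm_out per nested tensor, reading strided views of
  // the input buffers and writing into the unbound views of the output.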
  std::vector<Tensor> output_unbind = output.unbind();
  for (int64_t i = 0; i < ntensors; i++) {
    at::mm_out(output_unbind[i],
        self_buffer.as_strided(get_size_for_index(self, i), get_stride_for_index(self, i), get_offset_for_index(self, i)),
        mat2_buffer.as_strided(get_size_for_index(mat2, i), get_stride_for_index(mat2, i), get_offset_for_index(mat2, i)));
  }
  return output;
}

} // namespace native
} // namespace at