#include <type_traits>

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <ATen/native/cuda/PersistentSoftmax.cuh>
#include <ATen/native/cuda/block_reduce.cuh>

#include <c10/cuda/CUDAMathCompat.h>
#include <c10/cuda/CUDAStream.h>

#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
#include <ATen/native/nested/NestedTensorUtils.h>

#ifndef USE_ROCM
#ifndef _WIN32
#include <cutlass/gemm/device/default_gemm_configuration.h>
#include <cutlass/gemm/device/gemm_grouped.h>
#include <cutlass/gemm/kernel/default_gemm_grouped.h>
#endif
#endif

#include <ATen/NestedTensorImpl.h>

#define BLOCK_DIM 256
#define GRID_DIM_Y 16

namespace at {
namespace native {

#ifndef USE_ROCM
#ifndef _WIN32
namespace {

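// Runs a CUTLASS grouped GEMM: a single kernel launch computes problem_count
// independent GEMMs (one per nested tensor component), whose sizes, leading
// dimensions, and data pointers are supplied per problem.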
template <
    typename scalar_t,
    unsigned int kPad,
    typename LayoutA,
    typename LayoutB,
    typename OpClass,
    typename Arch,
    typename ThreadBlockShape,
    typename WarpShape,
    typename InstructionShape>
void gemm_grouped_cuda_internal(
    const std::vector<int64_t>& lda,
    const std::vector<int64_t>& ldb,
    const std::vector<int64_t>& ldd,
    const std::vector<scalar_t*>& aptr,
    const std::vector<scalar_t*>& bptr,
    const std::vector<scalar_t*>& dptr,
    const std::vector<cutlass::gemm::GemmCoord>& gemm_sizes,
    const int problem_count,
    at::Device& device) {
  using Element = scalar_t;
  using ElementAcc = float;

  using GemmConfiguration =
      typename cutlass::gemm::device::DefaultGemmConfiguration<
          OpClass,
          Arch,
          Element,
          Element,
          Element,
          ElementAcc>;

  using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
      Element,
      LayoutA,
      cutlass::ComplexTransform::kNone,
      kPad,
      Element,
      LayoutB,
      cutlass::ComplexTransform::kNone,
      kPad,
      Element,
      cutlass::layout::RowMajor,
      ElementAcc,
      OpClass,
      Arch,
      ThreadBlockShape,
      WarpShape,
      InstructionShape,
      typename GemmConfiguration::EpilogueOutputOp,
      cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
      GemmConfiguration::kStages>::GemmKernel;

  using GemmGrouped = typename cutlass::gemm::device::GemmGrouped<GemmKernel>;
  using EpilogueOutputOp = typename GemmGrouped::GemmKernel::Epilogue::OutputOp;
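  // The epilogue computes D = alpha * (A @ B) + beta * C; with alpha = 1 and
  // beta = 0 this is a plain matrix multiply.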
  typename EpilogueOutputOp::Params epilogue_op(/*alpha*/ 1, /*beta*/ 0);

  const int64_t gemm_coord_size =
      problem_count * ((int64_t)sizeof(cutlass::gemm::GemmCoord));
  // Number of gmm args not including *problem_sizes
  at::Tensor gmm_args = at::empty(
      {problem_count * 6 + gemm_coord_size},
      at::TensorOptions().dtype(at::kLong).pinned_memory(true));
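  // Layout of gmm_args: [lda | ldb | ldd | ptr_a | ptr_b | ptr_d | problem_sizes].
  // Packing everything into one pinned host buffer lets a single copy move all
  // arguments to the device.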

  // Obtain pointers for each argument (on host)
  int64_t* lda_data = gmm_args.data_ptr<int64_t>(); // Base pointer
  int64_t* ldb_data = lda_data + problem_count;
  int64_t* ldd_data = lda_data + 2 * problem_count;
  int64_t* ptr_a_data = lda_data + 3 * problem_count;
  int64_t* ptr_b_data = lda_data + 4 * problem_count;
  int64_t* ptr_d_data = lda_data + 5 * problem_count;
  cutlass::gemm::GemmCoord* problem_sizes_data =
      reinterpret_cast<cutlass::gemm::GemmCoord*>(lda_data + 6 * problem_count);

  // Set arguments into gmm_args from input args
  for (int i = 0; i < problem_count; ++i) {
    problem_sizes_data[i] = gemm_sizes[i];
    lda_data[i] = lda[i];
    ldb_data[i] = ldb[i];
    ldd_data[i] = ldd[i];
    ptr_a_data[i] = reinterpret_cast<int64_t>(aptr[i]);
    ptr_b_data[i] = reinterpret_cast<int64_t>(bptr[i]);
    ptr_d_data[i] = reinterpret_cast<int64_t>(dptr[i]);
  }
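  // Let CUTLASS decide how many threadblocks to launch for this set of problems.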
  const int threadblock_count =
      GemmGrouped::sufficient(problem_sizes_data, problem_count);

  // Transfer arguments to GPU
  gmm_args = gmm_args.to(device, true);

  // Obtain pointers for each of arguments (on GPU)
  lda_data = gmm_args.data_ptr<int64_t>(); // Base pointer
  ldb_data = lda_data + problem_count;
  ldd_data = lda_data + 2 * problem_count;
  ptr_a_data = lda_data + 3 * problem_count;
  ptr_b_data = lda_data + 4 * problem_count;
  ptr_d_data = lda_data + 5 * problem_count;
  problem_sizes_data =
      reinterpret_cast<cutlass::gemm::GemmCoord*>(lda_data + 6 * problem_count);

  // Create GemmGrouped::Arguments using the arguments prepared above
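  // Note: the same pointer array is passed for the C and D operands; since
  // beta == 0 the contents of C do not affect the result.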
  typename GemmGrouped::Arguments args(
      problem_sizes_data,
      problem_count,
      threadblock_count,
      epilogue_op,
      reinterpret_cast<Element**>(ptr_a_data),
      reinterpret_cast<Element**>(ptr_b_data),
      reinterpret_cast<Element**>(ptr_d_data),
      reinterpret_cast<Element**>(ptr_d_data),
      lda_data,
      ldb_data,
      ldd_data,
      ldd_data);

  GemmGrouped gemm;
  cutlass::Status status =
      gemm.initialize(args, nullptr, at::cuda::getCurrentCUDAStream());
  TORCH_CHECK(
      status != cutlass::Status::kErrorWorkspaceNull,
      "Failed to initialize CUTLASS Grouped GEMM kernel due to workspace.");
  TORCH_CHECK(
      status != cutlass::Status::kErrorInternal,
      "Failed to initialize CUTLASS Grouped GEMM kernel due to internal error.");
  TORCH_CHECK(
      status == cutlass::Status::kSuccess,
      "Failed to initialize CUTLASS Grouped GEMM kernel.");

  // Run CUTLASS group GEMM
  status = gemm.run(at::cuda::getCurrentCUDAStream());
  TORCH_CHECK(
      status == cutlass::Status::kSuccess,
      "Failed to run CUTLASS Grouped GEMM kernel.");

  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

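// group_gemm_dispatch picks a CUTLASS configuration for the given scalar type
// and calls gemm_grouped_cuda_internal. The generic template returns false so
// unsupported dtypes fall back to the per-component matmul path below.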
template <typename scalar_t>
bool group_gemm_dispatch(
    at::Device device,
    const std::vector<scalar_t*>& aptr,
    const std::vector<scalar_t*>& bptr,
    const std::vector<scalar_t*>& dptr,
    const std::vector<int64_t>& lda,
    const std::vector<int64_t>& ldb,
    const std::vector<int64_t>& ldd,
    std::vector<cutlass::gemm::GemmCoord> gemm_sizes,
    int64_t ntensors) {
  return false;
}

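// float: SIMT (CUDA-core) grouped GEMM with element alignment 1.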
template <>
bool group_gemm_dispatch(
    at::Device device,
    const std::vector<float*>& aptr,
    const std::vector<float*>& bptr,
    const std::vector<float*>& dptr,
    const std::vector<int64_t>& lda,
    const std::vector<int64_t>& ldb,
    const std::vector<int64_t>& ldd,
    std::vector<cutlass::gemm::GemmCoord> gemm_sizes,
    int64_t ntensors) {
  gemm_grouped_cuda_internal<
      float,
      1,
      cutlass::layout::RowMajor,
      cutlass::layout::RowMajor,
      cutlass::arch::OpClassSimt,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 128, 8>,
      cutlass::gemm::GemmShape<64, 32, 8>,
      cutlass::gemm::GemmShape<1, 1, 1>>(
      lda, ldb, ldd, aptr, bptr, dptr, gemm_sizes, ntensors, device);
  return true;
}

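// Half: use the tensor-core configuration (alignment 8) when n, k, and all
// leading dimensions are multiples of 8; otherwise fall back to a SIMT
// configuration with alignment 1.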
template <>
bool group_gemm_dispatch(
    at::Device device,
    const std::vector<c10::Half*>& aptr_,
    const std::vector<c10::Half*>& bptr_,
    const std::vector<c10::Half*>& dptr_,
    const std::vector<int64_t>& lda,
    const std::vector<int64_t>& ldb,
    const std::vector<int64_t>& ldd,
    std::vector<cutlass::gemm::GemmCoord> gemm_sizes,
    int64_t ntensors) {
  // Check alignment
  bool all_pad_8 = true;
  for (int i = 0; i < ntensors; i++) {
    all_pad_8 = all_pad_8 && (gemm_sizes[i].n() % 8 == 0);
    all_pad_8 = all_pad_8 && (gemm_sizes[i].k() % 8 == 0);

    // Not sure if this is a requirement, on the safe side
    all_pad_8 = all_pad_8 && (lda[i] % 8 == 0);
    all_pad_8 = all_pad_8 && (ldb[i] % 8 == 0);
    all_pad_8 = all_pad_8 && (ldd[i] % 8 == 0);
  }

  std::vector<cutlass::half_t*> aptr;
  std::vector<cutlass::half_t*> bptr;
  std::vector<cutlass::half_t*> dptr;
  for (int64_t i = 0; i < ntensors; i++) {
    aptr.push_back(reinterpret_cast<cutlass::half_t*>(aptr_[i]));
    bptr.push_back(reinterpret_cast<cutlass::half_t*>(bptr_[i]));
    dptr.push_back(reinterpret_cast<cutlass::half_t*>(dptr_[i]));
  }
  if (all_pad_8) {
    gemm_grouped_cuda_internal<
        cutlass::half_t,
        8,
        cutlass::layout::RowMajor,
        cutlass::layout::RowMajor,
        cutlass::arch::OpClassTensorOp,
        cutlass::arch::Sm80,
        cutlass::gemm::GemmShape<128, 128, 32>,
        cutlass::gemm::GemmShape<64, 64, 32>,
        cutlass::gemm::GemmShape<16, 8, 16>>(
        lda, ldb, ldd, aptr, bptr, dptr, gemm_sizes, ntensors, device);
    return true;
  } else {
    gemm_grouped_cuda_internal<
        cutlass::half_t,
        1,
        cutlass::layout::RowMajor,
        cutlass::layout::RowMajor,
        cutlass::arch::OpClassSimt,
        cutlass::arch::Sm80,
        cutlass::gemm::GemmShape<128, 128, 8>,
        cutlass::gemm::GemmShape<64, 32, 8>,
        cutlass::gemm::GemmShape<1, 1, 1>>(
        lda, ldb, ldd, aptr, bptr, dptr, gemm_sizes, ntensors, device);
    return true;
  }
  // Did not perform GEMM
  return false;
}

} // namespace

#endif
#endif

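// Batched matrix multiply for nested tensors. For contiguous, row-major
// float/half inputs on SM 8.x devices this runs a single CUTLASS grouped GEMM
// over all components; otherwise it falls back to one mm per component.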
Tensor bmm_nested_cuda(const Tensor& self, const Tensor& mat2) {
  // dispatcher should have guaranteed that at least one is nested
  auto self_ptr = self.is_nested() ? get_nested_tensor_impl(self) : self.unsafeGetTensorImpl();
  auto mat2_ptr = mat2.is_nested() ? get_nested_tensor_impl(mat2) : mat2.unsafeGetTensorImpl();
  TORCH_CHECK(self_ptr->dim() == 3, "batch1 must be a 3D tensor");
  TORCH_CHECK(mat2_ptr->dim() == 3, "batch2 must be a 3D tensor");
  int64_t ntensors = self_ptr->size(0), ntensors2 = mat2_ptr->size(0);
  TORCH_CHECK(
      ntensors == ntensors2,
      "Expected size for the 1st dimension of batch2 tensor to be: ",
      ntensors,
      " but got: ",
      ntensors2,
      ".");

  // create a contiguous output
  const Tensor& self_sizemat = self.is_nested() ?
      get_nested_tensor_impl(self)->get_nested_sizes() : get_nested_tensor_impl(mat2)->get_nested_sizes();

  Tensor out_sizemat = self_sizemat.new_empty(self_sizemat.sizes());
  int64_t* out_sizemat_ptr = out_sizemat.data_ptr<int64_t>();

  int64_t out_numel = 0;
  for (int64_t i = 0; i < ntensors; i++) {
    const IntArrayRef &self_shape = get_size_for_index(self, i), &mat2_shape = get_size_for_index(mat2, i);
    const int64_t &self_size0 = self_shape[0], &self_size1 = self_shape[1],
        &mat2_size0 = mat2_shape[0], &mat2_size1 = mat2_shape[1];
    TORCH_CHECK(
        self_size1 == mat2_size0,
        i,
        "-th nested matrices in batch cannot be multiplied (",
        self_size0,
        "x",
        self_size1,
        " and ",
        mat2_size0,
        "x",
        mat2_size1,
        ")");
    out_sizemat_ptr[0] = self_size0;
    out_sizemat_ptr[1] = mat2_size1;
    out_sizemat_ptr += 2;
    out_numel += self_size0 * mat2_size1;
  }

  const Tensor &self_buffer = self.is_nested() ? get_nested_tensor_impl(self)->get_unsafe_storage_as_tensor() : self;
  const Tensor &mat2_buffer = mat2.is_nested() ? get_nested_tensor_impl(mat2)->get_unsafe_storage_as_tensor() : mat2;

  Tensor out_buffer = self_buffer.new_empty(out_numel);
  Tensor output = wrap_buffer(out_buffer, out_sizemat);
  auto out_ptr = get_nested_tensor_impl(output);

  const int64_t *out_offsets_ptr = out_ptr->get_storage_offsets().const_data_ptr<int64_t>();

#ifndef USE_ROCM
#ifndef _WIN32
  bool success = false;
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      self.scalar_type(), "group_gemm_dispatch", [&] {
        std::vector<scalar_t*> aptr(ntensors);
        std::vector<scalar_t*> bptr(ntensors);
        std::vector<scalar_t*> dptr(ntensors);
        std::vector<int64_t> lda(ntensors);
        std::vector<int64_t> ldb(ntensors);
        std::vector<int64_t> ldd(ntensors);
        std::vector<cutlass::gemm::GemmCoord> gemm_sizes;
        bool all_row_major = true;
        for (int64_t i = 0; i < ntensors; i++) {
          const IntArrayRef& self_shape = get_size_for_index(self, i);
          const IntArrayRef& mat2_shape = get_size_for_index(mat2, i);
          const int64_t &self_size0 = self_shape[0];
          const int64_t &self_size1 = self_shape[1];
          const int64_t &mat2_size0 = mat2_shape[0];
          const int64_t &mat2_size1 = mat2_shape[1];
          gemm_sizes.push_back(
              cutlass::gemm::GemmCoord(self_size0, mat2_size1, self_size1));
          aptr[i] = self_buffer.data_ptr<scalar_t>() + get_offset_for_index(self, i);
          bptr[i] = mat2_buffer.data_ptr<scalar_t>() + get_offset_for_index(mat2, i);
          dptr[i] = out_buffer.data_ptr<scalar_t>() + out_offsets_ptr[i];
          auto self_stride = get_stride_for_index(self, i);
          auto mat2_stride = get_stride_for_index(mat2, i);
          all_row_major = all_row_major && (self_stride[1] == 1);
          all_row_major = all_row_major && (mat2_stride[1] == 1);
          lda[i] = self_stride[0];
          ldb[i] = mat2_stride[0];
          ldd[i] = mat2_size1;
        }
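        // The grouped kernels above are instantiated for cutlass::arch::Sm80,
        // so only take this path on SM 8.x devices with contiguous, row-major
        // inputs.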
        auto dprops = at::cuda::getCurrentDeviceProperties();
        bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
        if (all_row_major &&
            self.is_contiguous() &&
            mat2.is_contiguous() &&
            is_sm8x) {
          success = group_gemm_dispatch<scalar_t>(
              output.device(),
              aptr,
              bptr,
              dptr,
              lda,
              ldb,
              ldd,
              gemm_sizes,
              ntensors);
        }
      });
  if (success) {
    return output;
  }
#endif
#endif

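  // Fallback: compute each component with a regular mm over strided views into
  // the underlying buffers.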
  std::vector<Tensor> output_unbind = output.unbind();
  for (int64_t i = 0; i < ntensors; i++) {
    at::mm_out(output_unbind[i],
        self_buffer.as_strided(get_size_for_index(self, i), get_stride_for_index(self, i), get_offset_for_index(self, i)),
        mat2_buffer.as_strided(get_size_for_index(mat2, i), get_stride_for_index(mat2, i), get_offset_for_index(mat2, i)));
  }
  return output;
}

} // namespace native
} // namespace at