1# mypy: allow-untyped-decorators 2# mypy: allow-untyped-defs 3import math 4import os 5import weakref 6from functools import lru_cache 7from typing import Optional, Tuple 8 9import torch 10from torch._dynamo.utils import warn_once 11from torch.utils._triton import has_triton 12 13from ._triton_ops_meta import get_meta 14 15 16TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE = int( 17 os.getenv("TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE", 2) 18) 19 20 21def check(cond, msg): 22 if not cond: 23 raise ValueError(msg) 24 25 26def check_bsr_layout(f_name, t): 27 check( 28 t.layout == torch.sparse_bsr, 29 f"{f_name}(): only BSR sparse format is supported for the sparse argument.", 30 ) 31 32 33def check_device(f_name, t, device): 34 check( 35 t.device == device and t.device.type == "cuda", 36 f"{f_name}(): all inputs are expected to be on the same GPU device.", 37 ) 38 39 40def check_mm_compatible_shapes(f_name, lhs, rhs): 41 check( 42 lhs.dim() >= 2 and rhs.dim() >= 2, 43 f"{f_name}(): all inputs involved in the matrix product are expected to be at least 2D, " 44 f"but got lhs.dim() == {lhs.dim()} and rhs.dim() == {rhs.dim()}.", 45 ) 46 47 m, kl = lhs.shape[-2:] 48 kr, n = rhs.shape[-2:] 49 50 check( 51 kl == kr, 52 f"{f_name}(): arguments' sizes involved in the matrix product are not compatible for matrix multiplication, " 53 f"got lhs.shape[-1] == {kl} which is not equal to rhs.shape[-2] == {kr}.", 54 ) 55 56 57def check_dtype(f_name, t, dtype, *additional_dtypes): 58 check( 59 t.dtype == dtype 60 and t.dtype 61 in ((torch.half, torch.bfloat16, torch.float) + tuple(*additional_dtypes)), 62 f"{f_name}(): all inputs are expected to be of the same dtype " 63 f"and one of (half, bfloat16, float32) or {additional_dtypes}, " 64 f"but got dtype == {t.dtype}.", 65 ) 66 67 68def check_blocksize(f_name, blocksize): 69 assert len(blocksize) == 2 70 71 def is_power_of_two(v): 72 return not (v & (v - 1)) 73 74 def is_compatible_blocksize(b): 75 res = True 76 for blocksize in b: 77 # Triton loads only blocks which are at least 16 and powers of 2. 78 res = (blocksize >= 16 and is_power_of_two(blocksize)) and res 79 return res 80 81 check( 82 is_compatible_blocksize(blocksize), 83 f"{f_name}(): sparse inputs' blocksize ({blocksize[0]}, {blocksize[1]}) " 84 "should be at least 16 and a power of 2 in each dimension.", 85 ) 86 87 88def make_triton_contiguous(t): 89 """Return input as a triton-contiguous tensor. 90 91 A triton-contiguous tensor is defined as a tensor that has strides 92 with minimal value equal to 1. 93 94 While triton kernels support triton-non-contiguous tensors (all 95 strides being greater than 1 or having 0 strides) arguments, a 96 considerable slow-down occurs because tensor data is copied 97 element-wise rather than chunk-wise. 
98 """ 99 if min(t.stride()) != 1: 100 # TODO: investigate if contiguity along other axes than the 101 # last one can be beneficial for performance 102 return t.contiguous() 103 else: 104 return t 105 106 107def broadcast_batch_dims(f_name, *tensors): 108 try: 109 return torch.broadcast_shapes(*(t.shape[:-2] for t in tensors)) 110 except Exception: 111 check(False, f"{f_name}(): inputs' batch dimensions are not broadcastable!") 112 113 114def slicer(dim, slice_range, *tensors): 115 for t in tensors: 116 slices = [slice(None)] * t.dim() 117 slices[dim] = slice_range 118 yield t[slices] 119 120 121def multidim_slicer(dims, slices, *tensors): 122 for t in tensors: 123 s = [slice(None)] * t.dim() 124 for d, d_slice in zip(dims, slices): 125 if d is not None: 126 s[d] = d_slice 127 yield t[s] 128 129 130def ptr_stride_extractor(*tensors): 131 for t in tensors: 132 yield t 133 yield from t.stride() 134 135 136def grid_partitioner(full_grid, grid_blocks, tensor_dims_map): 137 assert 0 <= len(full_grid) <= 3 138 assert 0 <= len(grid_blocks) <= 3 139 140 import itertools 141 142 def generate_grid_points(): 143 for fg, mg in zip(full_grid, grid_blocks): 144 yield range(0, fg, mg) 145 146 def generate_sliced_tensors(slices): 147 for t, t_dims in tensor_dims_map.items(): 148 yield next(multidim_slicer(t_dims, slices, t)) 149 150 for grid_point in itertools.product(*generate_grid_points()): 151 grid = [ 152 min(fg - gp, mg) for fg, gp, mg in zip(full_grid, grid_point, grid_blocks) 153 ] 154 slices = [slice(gp, gp + g) for gp, g in zip(grid_point, grid)] 155 # grid_points are iterated in a "contiguous" order, i.e. 156 # left dimensions traversed slower than right dimensions. 157 # This order is reversed for CUDA grids. 158 yield grid[::-1], *generate_sliced_tensors(slices) 159 160 161def launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks=None): 162 # cuda_max_grid = (2 ** 31 - 1, 2 ** 16 - 1, 2 ** 16 - 1) 163 cuda_max_grid = (2147483647, 65535, 65535)[::-1] 164 if grid_blocks is None: 165 grid_blocks = cuda_max_grid 166 else: 167 168 def valid_grid_dim(g, mg): 169 if g is None: 170 return mg 171 else: 172 # grid must be at least 1 and no greater than mg 173 return max(1, min(g, mg)) 174 175 grid_blocks = tuple( 176 valid_grid_dim(g, mg) for g, mg in zip(grid_blocks, cuda_max_grid) 177 ) # type: ignore[assignment] 178 179 for grid, *sliced_tensors in grid_partitioner( 180 full_grid, grid_blocks, tensor_dims_map 181 ): 182 kernel(grid, *sliced_tensors) 183 184 185def prepare_inputs(bsr, *dense_tensors): 186 # Introduce fake batch dimension if not present for convenience. 187 crow_indices = bsr.crow_indices().unsqueeze(0) 188 col_indices = bsr.col_indices().unsqueeze(0) 189 values = make_triton_contiguous(bsr.values().unsqueeze(0)) 190 tensors = [make_triton_contiguous(t.unsqueeze(0)) for t in dense_tensors] 191 192 # Compute broadcasted batch dimension 193 batch_dims_broadcasted = torch.broadcast_shapes( 194 values.shape[:-3], *(t.shape[:-2] for t in tensors) 195 ) 196 197 # Broadcast batch dimensions and squash. 198 # The result can be either a view or a copy. 
199 def batch_broadcast_and_squash(t, batch_dims, invariant_dims): 200 return t.broadcast_to(batch_dims + invariant_dims).flatten( 201 0, len(batch_dims) - 1 202 ) 203 204 crow_indices = batch_broadcast_and_squash( 205 crow_indices, batch_dims_broadcasted, (-1,) 206 ) 207 208 col_indices = batch_broadcast_and_squash(col_indices, batch_dims_broadcasted, (-1,)) 209 values = batch_broadcast_and_squash( 210 values, batch_dims_broadcasted, values.shape[-3:] 211 ) 212 tensors = [ 213 batch_broadcast_and_squash(t, batch_dims_broadcasted, t.shape[-2:]) 214 for t in tensors 215 ] 216 217 return crow_indices, col_indices, values, *tensors 218 219 220def broadcast_batch_dims_bsr(f_name, bsr, *tensors): 221 batch_shape = broadcast_batch_dims(f_name, bsr, *tensors) 222 223 crow_indices = bsr.crow_indices().broadcast_to(batch_shape + (-1,)) 224 col_indices = bsr.col_indices().broadcast_to(batch_shape + (-1,)) 225 values = bsr.values().broadcast_to(batch_shape + bsr.values().shape[-3:]) 226 size = batch_shape + bsr.shape[-2:] 227 return torch.sparse_compressed_tensor( 228 crow_indices, col_indices, values, size=size, layout=bsr.layout 229 ) 230 231 232# NOTE: this function will ALWAYS create a view 233def tile_to_blocksize(t, blocksize): 234 *rest, m, n = t.shape 235 new_shape = rest + [ 236 m // blocksize[0], 237 blocksize[0], 238 n // blocksize[1], 239 blocksize[1], 240 ] 241 # using .view instead of .reshape to ensure that the result is 242 # indeed a view: 243 return t.view(new_shape).transpose(-3, -2) 244 245 246def as1Dbatch(tensor): 247 """Return tensor as 3D tensor by either prepending new dimensions to 248 the tensor shape (when ``tensor.ndim < 3``), or by collapsing 249 starting dimensions into the first dimension (when ``tensor.ndim > 250 3``). 251 """ 252 while tensor.ndim < 3: 253 tensor = tensor.unsqueeze(0) 254 if tensor.ndim > 3: 255 tensor = tensor.flatten(0, tensor.ndim - 3) 256 assert tensor.ndim == 3, tensor.shape 257 return tensor 258 259 260def scatter_mm(blocks, others, indices_data, *, accumulators=None): 261 """Scattered matrix multiplication of tensors. 262 263 A scattered matrix multiplication is defined as a series of matrix 264 multiplications applied to input tensors according to the input 265 and output mappings specified by indices data. 266 267 The following indices data formats are supported for defining a 268 scattered matrix multiplication operation (:attr:`indices_data[0]` 269 holds the name of the indices data format as specified below): 270 271 - ``"scatter_mm"`` - matrix multiplications scattered in batches 272 of tensors. 273 274 If :attr:`blocks` is a :math:`(* \times M \times K) tensor, 275 :attr:`others` is a :math:`(* \times K \times N)` tensor, 276 :attr:`accumulators` is a :math:`(* \times M \times N)` tensor, 277 and :attr:`indices = indices_data['indices']` is a :math:`(* 278 \times 3)` tensor, then the operation is equivalent to the 279 following code:: 280 281 c_offsets, pq = indices_data[1:] 282 for r in range(len(c_offsets) - 1): 283 for g in range(c_offsets[r], c_offsets[r + 1]): 284 p, q = pq[g] 285 accumulators[r] += blocks[p] @ others[q] 286 287 - ``"bsr_strided_mm"`` - matrix multiplications scattered in 288 batches of tensors and a tensor. 
289 290 If :attr:`blocks` is a :math:`(Ms \times Ks) tensor, 291 :attr:`others` is a :math:`(* \times K \times N)` tensor, 292 :attr:`accumulators` is a :math:`(* \times M \times N)` tensor, then 293 the operation is equivalent to the following code:: 294 295 c_indices, r_offsets, p_offsets, q_offsets, meta = indices_data[1:] 296 for b in range(nbatches): 297 for i, r in enumerate(r_offsets): 298 r0, r1 = divmod(r, N) 299 acc = accumulators[b, r0:r0 + Ms, r1:r1 + Ns] 300 for g in range(c_indices[i], c_indices[i+1]): 301 p = p_offsets[g] 302 q0, q1 = divmod(q_offsets[g], N) 303 acc += blocks[p] @ others[b, q0:q0 + Ks, q1:q1 + Ns] 304 305 where ``Ns = N // meta['SPLIT_N']``, and ``M`` and ``K`` are 306 integer multiples of ``Ms`` and ``Ks``, respectively. 307 308 - ``"bsr_strided_mm_compressed"`` - matrix multiplications 309 scattered in batches of tensors and a tensor. A memory and 310 processor efficient version of ``"bsr_strided_mm"`` format. If 311 :attr:`blocks` is a :math:`(Ms \times Ks) tensor, :attr:`others` 312 is a :math:`(* \times K \times N)` tensor, :attr:`accumulators` 313 is a :math:`(* \times M \times N)` tensor, then the operation is 314 equivalent to the following code:: 315 316 c_indices, r_offsets, q_offsets, meta = indices_data[1:] 317 for b in range(nbatches): 318 for r in r_offsets: 319 m = (r // N) // Ms 320 n = (r % N) // Ns 321 r0, r1 = divmod(r, N) 322 c0, c1 = c_indices[m], c_indices[m + 1] 323 acc = accumulators[b, r0:r0 + Ms, r1:r1 + Ns] 324 for i, p in enumerate(range(c0, c1)): 325 q = q_offsets[n * c1 + (SPLIT_N - n) * c0 + i] 326 q0, q1 = divmod(q, N) 327 acc += blocks[p] @ others[b, q0:q0 + Ks, q1:q1 + Ns] 328 329 where ``Ns = N // meta['SPLIT_N']``, and ``M`` and ``K`` are 330 integer multiples of ``Ms`` and ``Ks``, respectively. 331 332 Notice that the order of ``r_offsets`` items can be arbitrary; 333 this property enables defining swizzle operators via 334 rearrangements of ``r_offsets`` items.. 335 336 Auxilary functions are provided for pre-computing 337 :attr:`indices_data`. For example, 338 :func:`bsr_scatter_mm_indices_data` is used to define indices data 339 for matrix multiplication of BSR and strided tensors. 340 341 Parameters 342 ---------- 343 blocks (Tensor): a 3-D tensor of first matrices to be multiplied 344 345 others (Tensor): a tensor of second matrices to be multiplied. If 346 ``indices_data[0]=="scatter_mm"``, the tensor is a 1-D batch 347 tensor of second input matrices to be multiplied. Otherwise, the 348 second input matrices are slices of the :attr:`others` tensor. 349 indices_data (tuple): a format data that defines the inputs and 350 outputs of scattered matrix multiplications. 351 352 Keyword arguments 353 ----------------- 354 355 accumulators (Tensor, optional): a tensor of matrix product 356 accumulators. If ``indices_data[0]=="scatter_mm"``, the tensor 357 is a 1-D batch tensor of output matrices. Otherwise, output 358 matrices are slices of the :attr:`accumulators` tensor. 
359 """ 360 indices_format = indices_data[0] 361 362 assert blocks.ndim == 3 363 P, Ms, Ks = blocks.shape 364 365 if indices_format == "scatter_mm": 366 c_offsets, pq = indices_data[1:] 367 368 assert others.ndim == 3 369 Q, Ks_, Ns = others.shape 370 assert Ks == Ks_ 371 372 if accumulators is None: 373 R = c_offsets.shape[0] - 1 374 accumulators = torch.zeros( 375 (R, Ms, Ns), dtype=blocks.dtype, device=blocks.device 376 ) 377 else: 378 R, Ms_, Ns_ = accumulators.shape 379 assert Ms_ == Ms 380 assert Ns_ == Ns 381 382 if Ms % 16 or Ks % 16 or Ns % 16 or _scatter_mm2 is None: 383 for r in range(c_offsets.shape[0] - 1): 384 g0 = c_offsets[r] 385 g1 = c_offsets[r + 1] 386 for g in range(g0, g1): 387 p, q = pq[g] 388 accumulators[r] += blocks[p] @ others[q] 389 else: 390 _scatter_mm2(blocks, others, c_offsets, pq, accumulators) 391 return accumulators 392 393 elif indices_format == "bsr_strided_mm": 394 others_shape = others.shape 395 others = as1Dbatch(others) 396 397 B, K, N = others.shape 398 assert K % Ks == 0 399 400 c_indices, r_offsets, p_offsets, q_offsets, meta = indices_data[1:] 401 SPLIT_N = meta["SPLIT_N"] 402 403 if accumulators is None: 404 M = Ms + (r_offsets.max().item() + 1) // N 405 accumulators = torch.zeros( 406 (*others_shape[:-2], M, N), dtype=blocks.dtype, device=blocks.device 407 ) 408 else: 409 M, N_ = accumulators.shape[-2:] 410 assert N_ == N 411 412 accumulators_shape = accumulators.shape 413 accumulators = as1Dbatch(accumulators) 414 415 Ns = N // SPLIT_N 416 417 if Ms % 16 or Ks % 16 or Ns % 16 or _scatter_mm6 is None: 418 accumulators.zero_() 419 for b in range(B): 420 for r in range(r_offsets.shape[0]): 421 r_ = r_offsets[r].item() 422 g0 = c_indices[r].item() 423 g1 = c_indices[r + 1].item() 424 r0, r1 = divmod(r_, N) 425 acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns] 426 for g in range(g0, g1): 427 p, q = p_offsets[g], q_offsets[g] 428 q0, q1 = divmod(q.item(), N) 429 acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns] 430 else: 431 _scatter_mm6( 432 blocks, 433 others, 434 c_indices, 435 r_offsets, 436 p_offsets, 437 q_offsets, 438 meta, 439 accumulators, 440 ) 441 return accumulators.view(accumulators_shape) 442 443 elif indices_format == "bsr_strided_mm_compressed": 444 others_shape = others.shape 445 others = as1Dbatch(others) 446 447 B, K, N = others.shape 448 assert K % Ks == 0 449 450 c_indices, r_offsets, q_offsets, meta = indices_data[1:] 451 SPLIT_N = meta["SPLIT_N"] 452 453 if accumulators is None: 454 M = Ms + (r_offsets.max().item() + 1) // N 455 accumulators = torch.zeros( 456 (*others_shape[:-2], M, N), dtype=blocks.dtype, device=blocks.device 457 ) 458 else: 459 M, N_ = accumulators.shape[-2:] 460 assert N_ == N 461 462 accumulators_shape = accumulators.shape 463 accumulators = as1Dbatch(accumulators) 464 465 Ns = N // SPLIT_N 466 467 if Ms % 16 or Ks % 16 or Ns % 16 or _scatter_mm6 is None: 468 for b in range(B): 469 for j in range(len(r_offsets)): 470 r0, r1 = divmod(r_offsets[j].item(), N) 471 m = r0 // Ms 472 n = r1 // Ns 473 c0 = c_indices[m].item() 474 c1 = c_indices[m + 1].item() 475 acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns] 476 for i, p in enumerate(range(c0, c1)): 477 q = q_offsets[n * c1 + (SPLIT_N - n) * c0 + i].item() 478 q0, q1 = divmod(q, N) 479 acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns] 480 else: 481 p_offsets = torch.empty( 482 (0,), dtype=q_offsets.dtype, device=q_offsets.device 483 ) 484 _scatter_mm6( 485 blocks, 486 others, 487 c_indices, 488 r_offsets, 489 p_offsets, 490 q_offsets, 491 meta, 
492 accumulators, 493 ) 494 return accumulators.view(accumulators_shape) 495 496 else: 497 raise NotImplementedError(indices_format) 498 499 500def scatter_mm_meta( 501 M, 502 K, 503 N, 504 Ms, 505 Ks, 506 GROUP_SIZE=None, 507 TILE_M=None, 508 TILE_N=None, 509 SPLIT_N=None, 510 num_warps=None, 511 num_stages=None, 512 **extra, 513): 514 if {TILE_M, TILE_N, SPLIT_N, num_warps, num_stages, GROUP_SIZE} == {None}: 515 device_name = torch.cuda.get_device_name() 516 meta = get_meta( 517 "scatter_mm", 518 (M, K, N, Ms, Ks), 519 device_name, 520 version=(0, torch.float16, 0.5), 521 ) 522 if meta is not None: 523 meta.update(**extra) 524 return meta 525 # The following parameters are optimized for the performance 526 # equilibrium points of bsr-dense and dense-dense matrix 527 # multiplications when using GPU card NVIDIA GeForce RTX 2060 528 # SUPER. For points far from the performance equilibrium 529 # points as well as for other GPU cards, the optimal 530 # parameters are likely different from what specified below. 531 if (M, K, N) == (256,) * 3: 532 if (Ms, Ks) == (16, 16): 533 SPLIT_N = 1 534 TILE_M = 16 535 TILE_N = 16 536 GROUP_SIZE = 4 537 num_stages = 1 538 num_warps = 4 # noqa: E225,E231,E702 539 elif (Ms, Ks) == (32, 32): 540 SPLIT_N = 2 541 TILE_M = 32 542 TILE_N = 16 543 GROUP_SIZE = 4 544 num_stages = 1 545 num_warps = 4 # noqa: E225,E231,E702 546 elif (Ms, Ks) == (64, 64): 547 SPLIT_N = 1 548 TILE_M = 32 549 TILE_N = 32 550 GROUP_SIZE = 4 551 num_stages = 1 552 num_warps = 4 # noqa: E225,E231,E702 553 elif (Ms, Ks) == (128, 128): 554 SPLIT_N = 1 555 TILE_M = 32 556 TILE_N = 32 557 GROUP_SIZE = 2 558 num_stages = 1 559 num_warps = 4 # noqa: E225,E231,E702 560 elif (M, K, N) == (512,) * 3: 561 if (Ms, Ks) == (16, 16): 562 SPLIT_N = 8 563 TILE_M = 16 564 TILE_N = 64 565 GROUP_SIZE = 2 566 num_stages = 1 567 num_warps = 2 # noqa: E225,E231,E702 568 elif (Ms, Ks) == (32, 32): 569 SPLIT_N = 8 570 TILE_M = 32 571 TILE_N = 64 572 GROUP_SIZE = 4 573 num_stages = 1 574 num_warps = 2 # noqa: E225,E231,E702 575 elif (Ms, Ks) == (64, 64): 576 SPLIT_N = 4 577 TILE_M = 32 578 TILE_N = 128 579 GROUP_SIZE = 4 580 num_stages = 1 581 num_warps = 4 # noqa: E225,E231,E702 582 elif (Ms, Ks) == (128, 128): 583 SPLIT_N = 8 584 TILE_M = 64 585 TILE_N = 64 586 GROUP_SIZE = 4 587 num_stages = 1 588 num_warps = 4 # noqa: E225,E231,E702 589 elif (M, K, N) == (1024,) * 3: 590 if (Ms, Ks) == (16, 16): 591 SPLIT_N = 4 592 TILE_M = 16 593 TILE_N = 128 594 GROUP_SIZE = 2 595 num_stages = 1 596 num_warps = 1 # noqa: E225,E231,E702 597 elif (Ms, Ks) == (32, 32): 598 SPLIT_N = 8 599 TILE_M = 32 600 TILE_N = 64 601 GROUP_SIZE = 2 602 num_stages = 1 603 num_warps = 1 # noqa: E225,E231,E702 604 elif (Ms, Ks) == (64, 64): 605 SPLIT_N = 16 606 TILE_M = 64 607 TILE_N = 64 608 GROUP_SIZE = 4 609 num_stages = 1 610 num_warps = 2 # noqa: E225,E231,E702 611 elif (Ms, Ks) == (128, 128): 612 SPLIT_N = 16 613 TILE_M = 64 614 TILE_N = 64 615 GROUP_SIZE = 4 616 num_stages = 1 617 num_warps = 4 # noqa: E225,E231,E702 618 elif (Ms, Ks) == (256, 256): 619 SPLIT_N = 16 620 TILE_M = 64 621 TILE_N = 64 622 GROUP_SIZE = 2 623 num_stages = 1 624 num_warps = 4 # noqa: E225,E231,E702 625 elif (M, K, N) == (2048,) * 3: 626 if (Ms, Ks) == (16, 16): 627 SPLIT_N = 4 628 TILE_M = 16 629 TILE_N = 128 630 GROUP_SIZE = 8 631 num_stages = 1 632 num_warps = 1 # noqa: E225,E231,E702 633 elif (Ms, Ks) == (32, 32): 634 SPLIT_N = 4 635 TILE_M = 32 636 TILE_N = 64 637 GROUP_SIZE = 4 638 num_stages = 1 639 num_warps = 1 # noqa: E225,E231,E702 640 elif (Ms, Ks) 
== (64, 64): 641 SPLIT_N = 4 642 TILE_M = 64 643 TILE_N = 128 644 GROUP_SIZE = 4 645 num_stages = 1 646 num_warps = 4 # noqa: E225,E231,E702 647 elif (Ms, Ks) == (128, 128): 648 SPLIT_N = 8 649 TILE_M = 64 650 TILE_N = 64 651 GROUP_SIZE = 4 652 num_stages = 1 653 num_warps = 4 # noqa: E225,E231,E702 654 elif (Ms, Ks) == (256, 256): 655 SPLIT_N = 4 656 TILE_M = 64 657 TILE_N = 64 658 GROUP_SIZE = 2 659 num_stages = 1 660 num_warps = 4 # noqa: E225,E231,E702 661 elif (M, K, N) == (4096,) * 3: 662 if (Ms, Ks) == (16, 16): 663 SPLIT_N = 2 664 TILE_M = 16 665 TILE_N = 256 666 GROUP_SIZE = 2 667 num_stages = 1 668 num_warps = 2 # noqa: E225,E231,E702 669 elif (Ms, Ks) == (32, 32): 670 SPLIT_N = 2 671 TILE_M = 32 672 TILE_N = 64 673 GROUP_SIZE = 2 674 num_stages = 1 675 num_warps = 1 # noqa: E225,E231,E702 676 elif (Ms, Ks) == (64, 64): 677 SPLIT_N = 2 678 TILE_M = 64 679 TILE_N = 128 680 GROUP_SIZE = 2 681 num_stages = 1 682 num_warps = 4 # noqa: E225,E231,E702 683 684 if SPLIT_N is None: 685 # Assume NVIDIA GeForce RTX 2060 SUPER: 686 # With the probality of 92% (99.9% when N > 512), the 687 # performance will not be worse more than 2% from the 688 # performance when using an optimal value. Otherwise, when N 689 # <= 512, using the following heuristics may give upto 15% 690 # lower performance. 691 SPLIT_N = { 692 16: 1, 693 32: 2, 694 64: 4, 695 128: 8, 696 256: 16, 697 512: 8, 698 1024: 16, 699 4096: 32, 700 8192: 64, 701 }.get(N, 16) 702 if Ms >= 512 and N >= 2048: 703 SPLIT_N = 1 704 Ns = N // SPLIT_N 705 if TILE_M is None: 706 TILE_M = min(64 if Ns < 512 else 32, Ms) 707 if TILE_N is None: 708 TILE_N = min(64 if Ns < 512 else 32, Ns) 709 num_stages = num_stages or 1 710 if num_warps is None: 711 if min(M, N) > 1024: 712 num_warps = {16: 1, 32: 1, 64: 2}.get(Ms, 4) 713 elif min(M, N) == 1024: 714 num_warps = {16: 1, 32: 1, 64: 2}.get(Ms, 4) 715 elif min(M, N) == 256: 716 num_warps = {16: 1, 32: 4}.get(Ms, 4) 717 else: 718 num_warps = {16: 1, 32: 2}.get(Ms, 4) 719 GROUP_SIZE = GROUP_SIZE or 4 720 721 assert TILE_M <= Ms, dict(TILE_M=TILE_M, Ms=Ms) 722 assert TILE_N <= Ns, dict(TILE_N=TILE_N, Ns=Ns) 723 assert Ms <= M, dict(M=M, Ms=Ms) 724 assert Ns <= N, dict(N=N, Ns=Ns) 725 assert Ks <= K, dict(K=K, Ks=Ks) 726 727 return dict( 728 TILE_M=TILE_M, 729 TILE_N=TILE_N, 730 GROUP_SIZE=GROUP_SIZE, 731 num_stages=num_stages, 732 num_warps=num_warps, 733 SPLIT_N=SPLIT_N, 734 **extra, 735 ) 736 737 738def bsr_dense_addmm_meta( 739 M, 740 K, 741 N, 742 Ms, 743 Ks, 744 beta, 745 alpha, 746 SPLIT_N=None, 747 GROUP_SIZE_ROW=None, 748 num_warps=None, 749 num_stages=None, 750 sparsity=None, 751 dtype=None, 752 _version=0, 753 **extra, 754): 755 # Specifying _version is useful for situations when one wants to 756 # discard existing triton kernel tuning results, say, in testing 757 # bsr_dense_addmm_meta functionality. 758 if dtype is None: 759 dtype = torch.float16 760 if sparsity is None: 761 sparsity = 0.5 762 if {SPLIT_N, num_warps, num_stages, GROUP_SIZE_ROW} == {None}: 763 device_name = torch.cuda.get_device_name() 764 key = (M, K, N, Ms, Ks, beta == 0, beta == 1, alpha == 1) 765 meta = get_meta( 766 "bsr_dense_addmm", key, device_name, version=(_version, dtype, sparsity) 767 ) 768 if meta is None and sparsity != 0.5: 769 meta = get_meta( 770 "bsr_dense_addmm", key, device_name, version=(_version, dtype, 0.5) 771 ) 772 if meta is None: 773 # find approximate meta such that N % SPLIT_N == 0. 
774 matching_meta = get_meta( 775 "bsr_dense_addmm", 776 (*key[:2], "*", *key[3:]), 777 device_name, 778 version=(_version, dtype, 0.5), 779 ) 780 for mkey in sorted(matching_meta or {}): 781 meta_ = matching_meta[mkey] 782 n = mkey[2] 783 split_n = meta_["SPLIT_N"] 784 c = n // split_n 785 if N % c == 0 and n <= N: 786 meta = dict(meta_) 787 meta["SPLIT_N"] = N // c 788 if meta is not None: 789 meta.update(**extra) 790 return meta 791 else: 792 # see [Computing optimal kernel parameters] in 793 # _triton_ops_meta.py for ways to avoid this warning 794 # message 795 warn_once( 796 f"bsr_dense_addmm uses non-optimal triton kernel parameters for {M=} {K=} {N=} {Ms=}, {Ks=} {beta=} {alpha=}" 797 ) 798 799 SPLIT_N = SPLIT_N or max(N // Ms, 1) 800 GROUP_SIZE_ROW = GROUP_SIZE_ROW or 4 801 num_stages = num_stages or 1 802 num_warps = num_warps or 4 803 return dict( 804 SPLIT_N=SPLIT_N, 805 GROUP_SIZE_ROW=GROUP_SIZE_ROW, 806 num_stages=num_stages, 807 num_warps=num_warps, 808 **extra, 809 ) 810 811 812class TensorAsKey: 813 """A light-weight wrapper of a tensor that enables storing tensors as 814 keys with efficient memory reference based comparision as an 815 approximation to data equality based keys. 816 817 Motivation: the hash value of a torch tensor is tensor instance 818 based that does not use data equality and makes the usage of 819 tensors as keys less useful. For instance, the result of 820 ``len({a.crow_indices(), a.crow_indices()})`` is `2`, although, 821 the tensor results from `crow_indices` method call are equal, in 822 fact, these share the same data storage. 823 On the other hand, for efficient caching of tensors we want to 824 avoid calling torch.equal that compares tensors item-wise. 825 826 TensorAsKey offers a compromise in that it guarantees key equality 827 of tensors that references data in the same storage in the same 828 manner and without accessing underlying data. However, this 829 approach does not always guarantee correctness. For instance, for 830 a complex tensor ``x``, we have ``TensorAsKey(x) == 831 TensorAsKey(x.conj())`` while ``torch.equal(x, x.conj())`` would 832 return False. 833 """ 834 835 def __init__(self, obj): 836 def get_tensor_key(obj): 837 # Warning: TensorAsKey does not track negative nor 838 # conjugate bits of its input object because in the use 839 # case of wrapping compressed/plain indices of compressed 840 # sparse tensors (that are always integer tensors with 841 # non-negative items) these bits are never set. However, 842 # when extending the use of TensorAsKey to float or 843 # complex tensors, the values of these bits (see is_neg 844 # and is_conj methods) must be included in the key as 845 # well. 
846 assert not (obj.dtype.is_floating_point or obj.dtype.is_complex), obj.dtype 847 return ( 848 obj.data_ptr(), 849 obj.storage_offset(), 850 obj.shape, 851 obj.stride(), 852 obj.dtype, 853 ) 854 855 self._obj_ref = weakref.ref(obj) 856 if obj.layout is torch.strided: 857 self.key = get_tensor_key(obj) 858 elif obj.layout in {torch.sparse_csr, torch.sparse_bsr}: 859 self.key = ( 860 get_tensor_key(obj.crow_indices()), 861 get_tensor_key(obj.col_indices()), 862 ) 863 elif obj.layout in {torch.sparse_csc, torch.sparse_bsc}: 864 self.key = ( 865 get_tensor_key(obj.ccol_indices()), 866 get_tensor_key(obj.row_indices()), 867 ) 868 else: 869 raise NotImplementedError(obj.layout) 870 self._hash = hash(self.key) 871 872 def __hash__(self): 873 return self._hash 874 875 def __eq__(self, other): 876 if not isinstance(other, TensorAsKey): 877 return False 878 if self.obj is None or other.obj is None: 879 # dead objects always compare unequal unless these are 880 # same objects 881 return self is other 882 return self.key == other.key 883 884 @property 885 def obj(self): 886 """Return object if alive, otherwise None.""" 887 return self._obj_ref() 888 889 890@lru_cache(maxsize=TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE) 891def _bsr_scatter_mm_indices_data( 892 indices_format, M, K, N, Ms, Ks, nbatches, SPLIT_N, compressed_sparse_tensor_as_key 893): 894 bsr = compressed_sparse_tensor_as_key.obj 895 assert bsr is not None 896 crow_indices, col_indices = bsr.crow_indices(), bsr.col_indices() 897 device = crow_indices.device 898 indices_dtype = torch.int32 899 900 if indices_format == "bsr_strided_mm_compressed": 901 Ns = N // SPLIT_N 902 q_offsets_lst = [] 903 b = torch.arange(SPLIT_N, dtype=indices_dtype, device=device) * Ns 904 for m in range(M // Ms): 905 r0 = crow_indices[m].item() 906 r1 = crow_indices[m + 1].item() 907 if r1 == r0: 908 continue 909 q_offsets_lst.append( 910 (col_indices[r0:r1] * (Ks * N)).repeat(SPLIT_N) 911 + b.repeat_interleave(r1 - r0) 912 ) 913 q_offsets = torch.cat(q_offsets_lst) 914 crow_indices_diff = crow_indices.diff() 915 non_zero_row_indices = crow_indices_diff.nonzero() 916 a = non_zero_row_indices * (Ms * N) 917 r_offsets = (a + b).view(-1) 918 c_indices = crow_indices 919 # swizzle operation: mm elements with longer sums are computed first: 920 nnz_per_row = crow_indices_diff[non_zero_row_indices].repeat_interleave(SPLIT_N) 921 nnz_per_row, indices = nnz_per_row.sort(descending=True, stable=True) 922 r_offsets = r_offsets[indices] 923 return (indices_format, c_indices, r_offsets, q_offsets) 924 925 elif indices_format == "bsr_strided_mm": 926 Ns = N // SPLIT_N 927 p_offsets_lst = [] 928 q_offsets_lst = [] 929 b = torch.arange(SPLIT_N, dtype=indices_dtype, device=device) * Ns 930 for m in range(M // Ms): 931 r0 = crow_indices[m].item() 932 r1 = crow_indices[m + 1].item() 933 if r1 == r0: 934 continue 935 p_offsets_lst.append( 936 torch.arange(r0, r1, dtype=indices_dtype, device=device).repeat(SPLIT_N) 937 ) 938 q_offsets_lst.append( 939 (col_indices[r0:r1] * (Ks * N)).repeat(SPLIT_N) 940 + b.repeat_interleave(r1 - r0) 941 ) 942 q_offsets = torch.cat(q_offsets_lst) 943 crow_indices_diff = crow_indices.diff() 944 non_zero_row_indices = crow_indices_diff.nonzero() 945 a = non_zero_row_indices * (Ms * N) 946 r_offsets = (a + b).view(-1) 947 c_indices = torch.cat( 948 ( 949 crow_indices[:1], 950 torch.cumsum( 951 crow_indices_diff[non_zero_row_indices].repeat_interleave(SPLIT_N), 952 0, 953 ), 954 ) 955 ) 956 p_offsets = torch.cat(p_offsets_lst) 957 return (indices_format, 
c_indices, r_offsets, p_offsets, q_offsets) 958 959 elif indices_format == "scatter_mm": 960 Ns = Ms 961 c_indices = [0] 962 pq_offsets = [] 963 # todo: eliminate inner for-loops for efficiency 964 for b in range(nbatches): 965 for m in range(M // Ms): 966 r0 = crow_indices[m].item() 967 r1 = crow_indices[m + 1].item() 968 for n in range(N // Ns): 969 c_indices.append(c_indices[-1] + r1 - r0) 970 for t in range(r1 - r0): 971 p = r0 + t 972 q = (col_indices[p].item() + b * (K // Ks)) * (N // Ns) + n 973 pq_offsets.append([p, q]) 974 975 return ( 976 indices_format, 977 torch.tensor(c_indices, dtype=indices_dtype, device=device), 978 torch.tensor(pq_offsets, dtype=indices_dtype, device=device), 979 ) 980 981 else: 982 raise ValueError( 983 f"Invalid {indices_format=}. Expected bsr_strided_mm_compressed|bsr_strided_mm|scatter_mm" 984 ) 985 986 987def bsr_scatter_mm_indices_data( 988 bsr, other, indices_format="bsr_strided_mm_compressed", **meta_input 989): 990 """Computes indices data for :func:`scatter_mm` used in BSR and 991 strided tensor matrix multiplication. 992 """ 993 assert bsr.dense_dim() == 0 994 assert bsr.ndim == 2 # no batch dims 995 crow_indices = bsr.crow_indices() 996 col_indices = bsr.col_indices() 997 blocksize = bsr.values().shape[-2:] 998 M, K = bsr.shape 999 Ms, Ks = blocksize 1000 K_, N = other.shape[-2:] 1001 assert K_ == K 1002 nbatches = other.shape[:-2].numel() 1003 1004 meta = scatter_mm_meta(M, K, N, Ms, Ks, **meta_input) 1005 if "allow_tf32" not in meta_input: 1006 meta.update(allow_tf32=bsr.dtype in {torch.float16, torch.bfloat16}) 1007 SPLIT_N = meta["SPLIT_N"] 1008 indices_data = _bsr_scatter_mm_indices_data( 1009 indices_format, M, K, N, Ms, Ks, nbatches, SPLIT_N, TensorAsKey(bsr) 1010 ) 1011 1012 if indices_format == "bsr_strided_mm_compressed": 1013 meta.update(is_compressed=True) 1014 return indices_data + (meta,) 1015 elif indices_format == "bsr_strided_mm": 1016 meta.update(is_compressed=False) 1017 return indices_data + (meta,) 1018 else: 1019 return indices_data 1020 1021 1022def bsr_scatter_mm(bsr, other, indices_data=None, out=None): 1023 """BSR @ strided -> strided""" 1024 1025 assert bsr.ndim == 2 1026 assert other.ndim >= 2 1027 1028 Ms, Ks, Ns = bsr.shape[-2], bsr.shape[-1], other.shape[-1] 1029 blocksize = bsr.values().shape[-2:] 1030 1031 if indices_data is None: 1032 indices_data = bsr_scatter_mm_indices_data( 1033 bsr, other, indices_format="bsr_strided_mm_compressed" 1034 ) 1035 1036 indices_format = indices_data[0] 1037 1038 if out is None: 1039 out = torch.empty( 1040 (*other.shape[:-2], Ms, Ns), dtype=bsr.dtype, device=bsr.device 1041 ) 1042 out_shape = out.shape 1043 out = as1Dbatch(out) 1044 1045 if bsr._nnz() == 0: 1046 out.zero_() 1047 elif indices_format in {"bsr_strided_mm_compressed", "bsr_strided_mm"}: 1048 out.zero_() 1049 scatter_mm(bsr.values(), other, indices_data, accumulators=out) 1050 elif indices_format == "scatter_mm": 1051 nbatches = other.shape[:-2].numel() 1052 accumulators = torch.zeros( 1053 ( 1054 nbatches * Ms // blocksize[0] * Ns // blocksize[0], 1055 blocksize[0], 1056 blocksize[0], 1057 ), 1058 dtype=bsr.dtype, 1059 device=bsr.device, 1060 ) 1061 others = ( 1062 as1Dbatch(other) 1063 .transpose(-2, -1) 1064 .view( 1065 nbatches, 1066 Ns // blocksize[0], 1067 blocksize[0], 1068 Ks // blocksize[1], 1069 blocksize[1], 1070 ) 1071 .movedim( 1072 (3, 1, 4, 2), (1, 2, 3, 4) 1073 ) # equivalent to .transpose(-3, -2).transpose(-2, -1).transpose(-4, -3) 1074 .flatten(0, 2) 1075 ) 1076 scatter_mm(bsr.values(), others, 
indices_data, accumulators=accumulators) 1077 out.copy_( 1078 accumulators.unflatten( 1079 0, (nbatches, Ms // blocksize[0], Ns // blocksize[0]) 1080 ) 1081 .movedim( 1082 (1, 2, 3, 4), (3, 1, 4, 2) 1083 ) # equivalent to .transpose(-4, -3).transpose(-2, -1).transpose(-3, -2) 1084 .reshape(nbatches, Ns, Ms) 1085 .transpose(-2, -1) 1086 ) 1087 else: 1088 raise NotImplementedError(indices_format) 1089 1090 return out.view(out_shape) 1091 1092 1093def _int_bsr_dense_addmm( 1094 input: torch.Tensor, 1095 bsr: torch.Tensor, 1096 dense: torch.Tensor, 1097 *, 1098 beta=1, 1099 alpha=1, 1100 out: Optional[torch.Tensor] = None, 1101 skip_checks: bool = False, 1102 max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None, 1103 meta: Optional[dict] = None, 1104): 1105 if out is None and dense.dtype is torch.int8: 1106 f_name = "_int_bsr_dense_addmm" 1107 crow_indices = bsr.crow_indices() 1108 batch_ndim = crow_indices.dim() - 1 1109 M = bsr.shape[batch_ndim] 1110 N = dense.shape[-1] 1111 original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense) 1112 out = torch.empty( 1113 original_batch_dims_broadcasted + (M, N), 1114 dtype=torch.int32, 1115 device=dense.device, 1116 ) 1117 return bsr_dense_addmm( 1118 input, 1119 bsr, 1120 dense, 1121 beta=beta, 1122 alpha=alpha, 1123 out=out, 1124 skip_checks=skip_checks, 1125 max_grid=max_grid, 1126 meta=meta, 1127 ) 1128 1129 1130def bsr_dense_addmm( 1131 input: torch.Tensor, 1132 bsr: torch.Tensor, 1133 dense: torch.Tensor, 1134 *, 1135 beta=1, 1136 alpha=1, 1137 out: Optional[torch.Tensor] = None, 1138 skip_checks: bool = False, 1139 max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None, 1140 meta: Optional[dict] = None, 1141): 1142 f_name = "bsr_dense_addmm" 1143 values = bsr.values() 1144 crow_indices = bsr.crow_indices() 1145 col_indices = bsr.col_indices() 1146 batch_ndim = crow_indices.dim() - 1 1147 M, K = bsr.shape[batch_ndim : batch_ndim + 2] 1148 blocksize = values.shape[batch_ndim + 1 : batch_ndim + 3] 1149 N = dense.shape[-1] 1150 1151 # todo: implement checks 1152 1153 if out is None: 1154 original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense) 1155 out = dense.new_empty(original_batch_dims_broadcasted + (M, N)) 1156 1157 if bsr._nnz() == 0 or alpha == 0 or N == 0 or M == 0 or K == 0: 1158 if beta == 0: 1159 out.zero_() 1160 else: 1161 out.copy_(input) 1162 if beta != 1: 1163 out.mul_(beta) 1164 return out 1165 1166 if meta is None: 1167 sparsity = round(1 - bsr._nnz() * blocksize[0] * blocksize[1] / (M * K), 2) 1168 meta = bsr_dense_addmm_meta( 1169 M, 1170 K, 1171 N, 1172 blocksize[0], 1173 blocksize[1], 1174 beta, 1175 alpha, 1176 sparsity=sparsity, 1177 dtype=out.dtype, 1178 ) 1179 out_backup = out 1180 1181 crow_indices, col_indices, values, input, dense, out = prepare_inputs( 1182 bsr, input, dense, out 1183 ) 1184 1185 BM, BK = blocksize 1186 SPLIT_N = meta.get("SPLIT_N", N // BM) 1187 BN = N // SPLIT_N 1188 1189 out_untiled = out 1190 out = tile_to_blocksize(out, (BM, BN)) 1191 dense = tile_to_blocksize(dense, (BK, BN)) 1192 input = tile_to_blocksize(input, (BM, BN)) 1193 1194 dot_out_dtype = { 1195 torch.float16: tl.float32, 1196 torch.bfloat16: tl.float32, 1197 torch.float32: tl.float64, 1198 torch.float64: tl.float64, 1199 torch.int8: tl.int32, 1200 torch.int32: tl.int32, 1201 }[out.dtype] 1202 1203 n_batches = dense.size(0) 1204 n_block_rows = crow_indices.size(-1) - 1 1205 n_block_cols = dense.size(-3) 1206 1207 full_grid = (n_batches, n_block_cols, 
n_block_rows) 1208 if max_grid is not None: 1209 grid_blocks = tuple(max_grid[:3][::-1]) + (None,) * (3 - len(max_grid[:3])) 1210 else: 1211 grid_blocks = None 1212 1213 tensor_dims_map = { 1214 values: (0, None, None), 1215 crow_indices: (0, None, -1), 1216 col_indices: (0, None, None), 1217 input: (0, -3, -4), 1218 dense: (0, -3, None), 1219 out: (0, -3, -4), 1220 } 1221 1222 assert alpha != 0 1223 1224 def kernel(grid, *sliced_tensors): 1225 _bsr_strided_addmm_kernel[grid]( 1226 *ptr_stride_extractor(*sliced_tensors), 1227 beta, 1228 alpha, 1229 beta_is_one=beta == 1, 1230 beta_is_nonzero=beta != 0, 1231 alpha_is_one=alpha == 1, 1232 BLOCKSIZE_ROW=BM, 1233 BLOCKSIZE_INNER=BK, 1234 BLOCKSIZE_COL=BN, 1235 allow_tf32=dot_out_dtype == tl.float32, 1236 acc_dtype=dot_out_dtype, 1237 **meta, 1238 ) 1239 1240 launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks) 1241 1242 if out.data_ptr() != out_backup.data_ptr(): 1243 # prepare_inputs has made a copy of out, copy its content back 1244 # to out_backup: 1245 out_backup.copy_(out_untiled.view(out_backup.shape)) 1246 1247 return out_backup 1248 1249 1250if has_triton(): 1251 import triton 1252 import triton.language as tl 1253 1254 @triton.jit 1255 def _sampled_addmm_kernel( 1256 alpha, 1257 beta, 1258 IS_BETA_ZERO: tl.constexpr, 1259 BLOCKSIZE_ROW: tl.constexpr, 1260 BLOCKSIZE_COL: tl.constexpr, 1261 k, 1262 TILE_K: tl.constexpr, 1263 values_ptr, 1264 values_batch_stride, 1265 values_nnz_stride, 1266 values_row_block_stride, 1267 values_col_block_stride, 1268 crow_indices_ptr, 1269 crow_indices_batch_stride, 1270 crow_indices_stride, 1271 col_indices_ptr, 1272 col_indices_batch_stride, 1273 col_indices_stride, 1274 mat1_ptr, 1275 mat1_batch_stride, 1276 mat1_tiled_row_stride, 1277 mat1_tiled_col_stride, 1278 mat1_row_block_stride, 1279 mat1_col_block_stride, 1280 mat2_ptr, 1281 mat2_batch_stride, 1282 mat2_tiled_row_stride, 1283 mat2_tiled_col_stride, 1284 mat2_row_block_stride, 1285 mat2_col_block_stride, 1286 acc_dtype: tl.constexpr, 1287 allow_tf32: tl.constexpr, 1288 ): 1289 batch_pid = tl.program_id(axis=1) 1290 row_block_pid = tl.program_id(axis=0) 1291 1292 crow_indices_offset_ptr = ( 1293 crow_indices_ptr 1294 + crow_indices_batch_stride * batch_pid 1295 + crow_indices_stride * row_block_pid 1296 ) 1297 nnz_offset = tl.load(crow_indices_offset_ptr) 1298 nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride) 1299 1300 # Compute nnz for the row with number row_block_pid. 1301 # If it is zero, skip the row. 1302 row_nnz = nnz_offset_next - nnz_offset 1303 if row_nnz == 0: 1304 return 1305 1306 row_block_arange = tl.arange(0, BLOCKSIZE_ROW) 1307 col_block_arange = tl.arange(0, BLOCKSIZE_COL) 1308 1309 # Pointers are set to the first block of the current row. 1310 values_block_ptrs = ( 1311 values_ptr 1312 + values_batch_stride * batch_pid 1313 + values_nnz_stride * nnz_offset 1314 + values_row_block_stride * row_block_arange[:, None] 1315 + values_col_block_stride * col_block_arange[None, :] 1316 ) 1317 1318 col_index_nnz_ptr = ( 1319 col_indices_ptr 1320 + col_indices_batch_stride * batch_pid 1321 + col_indices_stride * nnz_offset 1322 ) 1323 1324 # Advance mat1 to the current tiled row, ignore columns. 1325 mat1_block_ptrs = ( 1326 mat1_ptr 1327 + mat1_batch_stride * batch_pid 1328 + mat1_tiled_row_stride * row_block_pid 1329 + mat1_row_block_stride * row_block_arange[:, None] 1330 ) 1331 1332 # Advance mat2 in batch and block col dimension. 
1333 mat2_block_ptrs = ( 1334 mat2_ptr 1335 + mat2_batch_stride * batch_pid 1336 + mat2_col_block_stride * col_block_arange[None, :] 1337 ) 1338 1339 k_tile_arange = tl.arange(0, TILE_K) 1340 for _ in range(row_nnz): 1341 acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype) 1342 1343 # find column block index 1344 col_block = tl.load(col_index_nnz_ptr) 1345 1346 for k_tile in range(0, k, TILE_K): 1347 k_offsets = k_tile + k_tile_arange 1348 mask_k = k_offsets < k 1349 1350 mat1_block = tl.load( 1351 mat1_block_ptrs + mat1_col_block_stride * k_offsets[None, :], 1352 mask=mask_k[None, :], 1353 other=0.0, 1354 ) 1355 1356 mat2_block = tl.load( 1357 mat2_block_ptrs 1358 + mat2_tiled_col_stride * col_block 1359 + mat2_row_block_stride * k_offsets[:, None], 1360 mask=mask_k[:, None], 1361 other=0.0, 1362 ) 1363 1364 acc_block += tl.dot( 1365 mat1_block, mat2_block, allow_tf32=allow_tf32, out_dtype=acc_dtype 1366 ) 1367 1368 if IS_BETA_ZERO: 1369 acc_block *= alpha 1370 else: 1371 acc_block = alpha * acc_block + beta * tl.load(values_block_ptrs) 1372 1373 # write result 1374 tl.store(values_block_ptrs, acc_block.to(values_ptr.dtype.element_ty)) 1375 1376 # advance val/col_index ptrs to the next block in the row. 1377 values_block_ptrs += values_nnz_stride 1378 col_index_nnz_ptr += col_indices_stride 1379 1380 @triton.jit 1381 def _bsr_strided_dense_rowspace_kernel( 1382 # values prologue 1383 values_ptr, 1384 values_batch_stride, 1385 values_nnz_stride, 1386 values_row_block_stride, 1387 values_col_block_stride, 1388 # values epilogue 1389 # crow_indices prologue 1390 crow_indices_ptr, 1391 crow_indices_batch_stride, 1392 crow_indices_stride, 1393 # crow_indices epilogue 1394 # col_indices prologue 1395 col_indices_ptr, 1396 col_indices_batch_stride, 1397 col_indices_stride, 1398 # col_indices epilogue 1399 # dense prologue 1400 dense_ptr, 1401 dense_batch_stride, 1402 dense_tiled_row_stride, 1403 dense_tiled_col_stride, 1404 dense_row_block_stride, 1405 dense_col_block_stride, 1406 # dense epilogue 1407 # output prologue 1408 output_ptr, 1409 output_batch_stride, 1410 output_tiled_row_stride, 1411 output_tiled_col_stride, 1412 output_row_block_stride, 1413 output_col_block_stride, 1414 # output epilogue 1415 # 1416 # gh-113754: Always keep all constexpr arguments at the end of 1417 # triton kernel arguments list because with triton 2.1 or 1418 # earlier non-contiguous outputs will corrupt CUDA state due 1419 # to a triton bug (fixed in openai/triton#2262). 1420 BLOCKSIZE_ROW: tl.constexpr, 1421 BLOCKSIZE_COL: tl.constexpr, 1422 acc_dtype: tl.constexpr, 1423 allow_tf32: tl.constexpr, 1424 GROUP_SIZE_ROW: tl.constexpr, 1425 ): 1426 batch_pid = tl.program_id(axis=2) 1427 row_block_pid = tl.program_id(axis=0) 1428 col_block_pid = tl.program_id(axis=1) 1429 n_block_rows = tl.num_programs(axis=0) 1430 n_block_cols = tl.num_programs(axis=1) 1431 1432 row_block_pid, col_block_pid = tl.swizzle2d( 1433 row_block_pid, col_block_pid, n_block_rows, n_block_cols, GROUP_SIZE_ROW 1434 ) 1435 1436 crow_indices_offset_ptr = ( 1437 crow_indices_ptr 1438 + crow_indices_batch_stride * batch_pid 1439 + crow_indices_stride * row_block_pid 1440 ) 1441 nnz_offset = tl.load(crow_indices_offset_ptr) 1442 nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride) 1443 1444 # Compute nnz for the row with number row_block_pid. 1445 # If it is zero, skip the row. 
1446 row_nnz = nnz_offset_next - nnz_offset 1447 if row_nnz == 0: 1448 return 1449 1450 row_block_arange = tl.arange(0, BLOCKSIZE_ROW) 1451 col_block_arange = tl.arange(0, BLOCKSIZE_COL) 1452 1453 # Pointers are set to the first block of the current row. 1454 values_block_ptrs = ( 1455 values_ptr 1456 + values_batch_stride * batch_pid 1457 + values_nnz_stride * nnz_offset 1458 + values_row_block_stride * row_block_arange[:, None] 1459 + values_col_block_stride * col_block_arange[None, :] 1460 ) 1461 1462 # NOTE: dense is advanced into all dimensions but the tiled row one. 1463 # That will be advanced in the loop according to values in col_indices. 1464 dense_block_ptrs = ( 1465 dense_ptr 1466 + dense_batch_stride * batch_pid 1467 + dense_tiled_col_stride * col_block_pid 1468 + dense_row_block_stride * col_block_arange[:, None] 1469 + dense_col_block_stride * row_block_arange[None, :] 1470 ) 1471 1472 # Pointers are set to exact write-to locations 1473 output_ptrs = ( 1474 output_ptr 1475 + output_batch_stride * batch_pid 1476 + output_tiled_row_stride * row_block_pid 1477 + output_tiled_col_stride * col_block_pid 1478 + output_row_block_stride * row_block_arange[:, None] 1479 + output_col_block_stride * row_block_arange[None, :] 1480 ) 1481 1482 # Set pointer to the first nonzero element in the current row 1483 col_index_nnz_ptr = ( 1484 col_indices_ptr 1485 + col_indices_batch_stride * batch_pid 1486 + col_indices_stride * nnz_offset 1487 ) 1488 1489 output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype) 1490 for _ in range(row_nnz): 1491 values_block = tl.load(values_block_ptrs) 1492 1493 # find which row of dense needs to get loaded 1494 # for multiplication with values_block. 1495 dense_row_idx = tl.load(col_index_nnz_ptr) 1496 dense_block = tl.load( 1497 dense_block_ptrs + dense_tiled_row_stride * dense_row_idx 1498 ) 1499 1500 # do block mm 1501 output_acc_block += tl.dot( 1502 values_block, dense_block, allow_tf32=allow_tf32, out_dtype=acc_dtype 1503 ) 1504 1505 # move val/col_index ptrs to the next block in the row 1506 values_block_ptrs += values_nnz_stride 1507 col_index_nnz_ptr += col_indices_stride 1508 1509 # write back the result 1510 tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty)) 1511 1512 def _run_sampled_addmm_kernel( 1513 alpha, 1514 beta, 1515 is_beta_zero, 1516 blocksize, 1517 k, 1518 tile_k, 1519 values, 1520 crow_indices, 1521 col_indices, 1522 mat1, 1523 mat2, 1524 max_grid, 1525 ): 1526 n_batches = values.size(0) 1527 n_block_rows = crow_indices.size(-1) - 1 1528 1529 full_grid = (n_batches, n_block_rows) 1530 if max_grid is not None: 1531 grid_blocks = tuple(max_grid[:2][::-1]) + (None,) * (2 - len(max_grid[:2])) 1532 else: 1533 grid_blocks = None 1534 tensor_dims_map = { 1535 values: (0, None), 1536 crow_indices: (0, -1), 1537 col_indices: (0, None), 1538 mat1: (0, -4), 1539 mat2: (0, None), 1540 } 1541 if values.dtype in (torch.half, torch.bfloat16): 1542 acc_dtype = tl.float32 1543 allow_tf32 = True 1544 else: 1545 acc_dtype = tl.float64 1546 allow_tf32 = False 1547 1548 def kernel(grid, *sliced_tensors): 1549 _sampled_addmm_kernel[grid]( 1550 alpha, 1551 beta, 1552 is_beta_zero, 1553 *blocksize, 1554 k, 1555 tile_k, 1556 *ptr_stride_extractor(*sliced_tensors), 1557 acc_dtype=acc_dtype, 1558 allow_tf32=allow_tf32, 1559 num_stages=1, 1560 num_warps=4, 1561 ) 1562 1563 launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks) 1564 1565 def sampled_addmm( 1566 input: torch.Tensor, 1567 mat1: torch.Tensor, 1568 
mat2: torch.Tensor, 1569 *, 1570 beta=1.0, 1571 alpha=1.0, 1572 out: Optional[torch.Tensor] = None, 1573 skip_checks: bool = False, 1574 max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None, 1575 ): 1576 f_name = "sampled_addmm" 1577 1578 check_bsr_layout(f_name, input) 1579 input_broadcasted = broadcast_batch_dims_bsr(f_name, input, mat1, mat2) 1580 1581 if not skip_checks: 1582 check_device(f_name, mat1, input.device) 1583 check_device(f_name, mat2, input.device) 1584 if beta != 0.0 and input.dtype is torch.bool: 1585 check( 1586 False, 1587 f"{f_name}(): having beta == {beta} not equal to 0.0 with boolean mask is not allowed.", 1588 ) 1589 if input.dtype is not torch.bool: 1590 check_dtype(f_name, mat1, input.dtype) 1591 check_dtype(f_name, mat2, input.dtype) 1592 else: 1593 check_dtype(f_name, mat1, mat2.dtype) 1594 check_mm_compatible_shapes(f_name, mat1, mat2) 1595 if out is not None: 1596 check_bsr_layout(f_name, out) 1597 check_device(f_name, out, mat1.device) 1598 check_dtype(f_name, out, input.dtype) 1599 check( 1600 out.shape == input_broadcasted.shape and out._nnz() == input._nnz(), 1601 f"{f_name}(): Expects `out` to be of shape {input_broadcasted.shape} " 1602 f"and with nnz equal to {input_broadcasted._nnz()} " 1603 f"but got out.shape = {out.shape} and out.nnz = {out._nnz()}", 1604 ) 1605 1606 if out is None: 1607 out = input_broadcasted.to(mat1.dtype, copy=True) 1608 else: 1609 out.copy_(input_broadcasted) 1610 1611 if out.numel() == 0 or out._nnz() == 0: 1612 return out 1613 1614 blocksize = out.values().shape[-2:] 1615 m = mat1.size(-2) 1616 n = mat2.size(-1) 1617 k = mat1.size(-1) 1618 1619 # NOTE: (m, 0) @ (0, n) == zeros(m, n) 1620 if alpha == 0.0 or k == 0: 1621 out.values().mul_(beta) 1622 return out 1623 1624 # prepare inputs by reshaping them to be kernel-compatible 1625 out_backup = out 1626 crow_indices, col_indices, values, mat1, mat2 = prepare_inputs(out, mat1, mat2) 1627 1628 mat1 = tile_to_blocksize(mat1, (blocksize[0], k)) 1629 mat2 = tile_to_blocksize(mat2, (k, blocksize[1])) 1630 tile_k = max(*blocksize) 1631 1632 _run_sampled_addmm_kernel( 1633 alpha, 1634 beta, 1635 beta == 0.0, 1636 blocksize, 1637 k, 1638 tile_k, 1639 values, 1640 crow_indices, 1641 col_indices, 1642 mat1, 1643 mat2, 1644 max_grid, 1645 ) 1646 1647 # If nnz x block strides are not the same in out_backup.values and values, 1648 # it means that out_backup.values and values are not the views of each other, 1649 # so we have to copy. 
1650 if out_backup.values().stride()[-3:] != values.stride()[-3:]: 1651 out_backup.values().copy_(values.reshape(out_backup.values().shape)) 1652 return out_backup 1653 1654 def bsr_dense_mm( 1655 bsr: torch.Tensor, 1656 dense: torch.Tensor, 1657 *, 1658 out: Optional[torch.Tensor] = None, 1659 skip_checks: bool = False, 1660 max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None, 1661 meta: Optional[dict] = None, 1662 ): 1663 f_name = "bsr_dense_mm" 1664 m, kl = bsr.shape[-2:] 1665 if not skip_checks: 1666 check_bsr_layout(f_name, bsr) 1667 check_device(f_name, bsr, dense.device) 1668 check_dtype(f_name, bsr, dense.dtype, (torch.int8,)) 1669 check_mm_compatible_shapes(f_name, bsr, dense) 1670 1671 n = dense.size(-1) 1672 row_block, col_block = bsr.values().shape[-2:] 1673 check_blocksize(f_name, (row_block, col_block)) 1674 check( 1675 not n % 16, 1676 f"{f_name}(): dense.size(-1) == {n} should be divisible by 16", 1677 ) 1678 else: 1679 kr, n = dense.shape[-2:] 1680 1681 original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense) 1682 1683 if out is not None and not skip_checks: 1684 expected_out_shape = original_batch_dims_broadcasted + (m, n) 1685 check( 1686 out.shape == expected_out_shape, 1687 "bsr_dense_mm(): `out` argument has wrong shape, " 1688 f"expected {expected_out_shape}, but got {out.shape}.", 1689 ) 1690 check( 1691 out.is_contiguous() or out.transpose(-2, -1).is_contiguous(), 1692 "bsr_dense_mm(): only row-major/col-major `out` arguments are supported, " 1693 "i.e. (out.is_contiguous() or out.transpose(-2, -1).is_contiguous()) " 1694 "should be True.", 1695 ) 1696 1697 # Allocate out 1698 if out is None: 1699 out = dense.new_empty(original_batch_dims_broadcasted + (m, n)) 1700 1701 # Short circuit if lhs is zero 1702 if bsr._nnz() == 0: 1703 return out.zero_() 1704 1705 # with beta==0, addmm ignores input content, so we can use out 1706 # as a placeholder for input because their shapes match: 1707 return bsr_dense_addmm(out, bsr, dense, alpha=1, beta=0, out=out) 1708 1709 @triton.jit 1710 def _bsr_softmax_kernel( 1711 crow_indices_ptr, 1712 crow_indices_batch_stride, 1713 crow_indices_stride, 1714 values_ptr, 1715 values_batch_stride, 1716 values_row_block_stride, 1717 values_nnz_col_block_stride, 1718 row_block, 1719 col_block, 1720 MAX_ROW_NNZ: tl.constexpr, 1721 TILE: tl.constexpr, 1722 ): 1723 batch_pid = tl.program_id(axis=2) 1724 row_block_offset_pid = tl.program_id(axis=1) 1725 row_block_pid = tl.program_id(axis=0) 1726 1727 crow_indices_offset_ptr = ( 1728 crow_indices_ptr 1729 + crow_indices_batch_stride * batch_pid 1730 + crow_indices_stride * row_block_pid 1731 ) 1732 nnz_offset = tl.load(crow_indices_offset_ptr) 1733 nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride) 1734 1735 # Compute nnz for the row with number row_block_pid. 1736 # If it is zero, skip the row. 
1737 row_nnz = nnz_offset_next - nnz_offset 1738 if row_nnz == 0: 1739 return 1740 1741 row_arange = tl.arange(0, TILE) 1742 mask = row_arange < row_nnz * col_block 1743 1744 curr_row_values_ptrs = ( 1745 values_ptr 1746 + values_batch_stride * batch_pid 1747 + values_row_block_stride * row_block_offset_pid 1748 + nnz_offset * col_block 1749 ) 1750 1751 # find max in the row 1752 row_tile = tl.load( 1753 curr_row_values_ptrs + row_arange, mask=mask, other=-float("inf") 1754 ).to(tl.float32) 1755 max_row_value = tl.max(row_tile, axis=0) 1756 for _ in range(TILE, MAX_ROW_NNZ, TILE): 1757 row_arange += TILE 1758 mask = row_arange < row_nnz * col_block 1759 row_tile = tl.load( 1760 curr_row_values_ptrs + row_arange, mask=mask, other=-float("inf") 1761 ).to(tl.float32) 1762 curr_max_row_value = tl.max(row_tile, axis=0) 1763 max_row_value = tl.where( 1764 max_row_value > curr_max_row_value, max_row_value, curr_max_row_value 1765 ) 1766 1767 # find denominator for stable softmax 1768 num = tl.exp(row_tile - max_row_value) 1769 denom = tl.sum(num, axis=0) 1770 for _ in range(TILE, MAX_ROW_NNZ, TILE): 1771 row_arange -= TILE 1772 mask = row_arange < row_nnz * col_block 1773 row_tile = tl.load( 1774 curr_row_values_ptrs + row_arange, mask=mask, other=-float("inf") 1775 ).to(tl.float32) 1776 num = tl.exp(row_tile - max_row_value) 1777 denom += tl.sum(num, axis=0) 1778 1779 # populate output 1780 tl.store( 1781 curr_row_values_ptrs + row_arange, 1782 (num / denom).to(values_ptr.dtype.element_ty), 1783 mask=mask, 1784 ) 1785 for _ in range(TILE, MAX_ROW_NNZ, TILE): 1786 row_arange += TILE 1787 mask = row_arange < row_nnz * col_block 1788 row_tile = tl.load( 1789 curr_row_values_ptrs + row_arange, mask=mask, other=-float("inf") 1790 ).to(tl.float32) 1791 num = tl.exp(row_tile - max_row_value) 1792 tl.store( 1793 curr_row_values_ptrs + row_arange, 1794 (num / denom).to(values_ptr.dtype.element_ty), 1795 mask=mask, 1796 ) 1797 1798 def bsr_softmax(input, max_row_nnz=None): 1799 f_name = "bsr_softmax" 1800 1801 check_bsr_layout(f_name, input) 1802 check_dtype(f_name, input, input.dtype) 1803 1804 if input._nnz() == 0 or input.numel() == 0: 1805 return input.clone() 1806 1807 m, n = input.shape[-2:] 1808 nnz = input._nnz() 1809 row_block, col_block = input.values().shape[-2:] 1810 1811 if max_row_nnz is None: 1812 max_row_nnz = triton.next_power_of_2(n) 1813 else: 1814 max_row_nnz = triton.next_power_of_2(max_row_nnz) 1815 1816 crow_indices = input.crow_indices().unsqueeze(0).flatten(0, -2) 1817 # reshape values from 1818 # (b1, ..., bn, nnz, row_block, col_block) to 1819 # (b1 * ... * bn, row_block, nnz * col_block). 1820 # This simplifies batch dim manipulation and unlocks 1821 # the possibility to access all nnzs in any given row. 1822 if input.values().transpose(-3, -2).is_contiguous(): 1823 # Need to clone to avoid `contiguous` returning a view. 
1824 values = input.values().clone() 1825 else: 1826 values = input.values() 1827 values = ( 1828 values.transpose(-3, -2) 1829 .contiguous() 1830 .unsqueeze(0) 1831 .flatten(0, -4) 1832 .reshape(-1, row_block, nnz * col_block) 1833 ) 1834 full_grid = (values.shape[0], row_block, m // row_block) 1835 grid_blocks = None 1836 tensor_dims_map = { 1837 # We span nnz number of blocks, not nnz + 1, 1838 # hence crow_indices[..., :-1] 1839 crow_indices[..., :-1]: (0, None, -1), 1840 values: (0, None, None), 1841 } 1842 1843 def kernel(grid, *sliced_tensors): 1844 _bsr_softmax_kernel[grid]( 1845 *ptr_stride_extractor(*sliced_tensors), 1846 row_block, 1847 col_block, 1848 max_row_nnz, 1849 # Triton's max numel is bounded by 2 ** 17. 1850 min(2**17, max_row_nnz), 1851 ) 1852 1853 launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks) 1854 1855 values = ( 1856 values.reshape(-1, row_block, nnz, col_block) 1857 .transpose(-3, -2) 1858 .reshape(*input.values().shape) 1859 ) 1860 1861 return torch.sparse_compressed_tensor( 1862 input.crow_indices().clone(), 1863 input.col_indices().clone(), 1864 values, 1865 size=input.shape, 1866 layout=input.layout, 1867 ) 1868 1869 def _scaled_dot_product_attention( 1870 query: torch.Tensor, 1871 key: torch.Tensor, 1872 value: torch.Tensor, 1873 attn_mask: Optional[torch.Tensor], 1874 dropout_p: float = 0.0, 1875 is_causal: bool = False, 1876 scale: Optional[float] = None, 1877 ): 1878 f_name = "_scaled_dot_product_attention" 1879 check(not is_causal, f"{f_name}(): is_causal == True is not supported.") 1880 check(attn_mask is not None, f"{f_name}(): attn_mask == None is not supported.") 1881 assert attn_mask is not None 1882 1883 check( 1884 attn_mask.layout == torch.sparse_bsr, 1885 f"{f_name}(): " 1886 f"attn_mask.layout must be {torch.sparse_bsr}, but got " 1887 f"attn_mask.layout == {attn_mask.layout}.", 1888 ) 1889 1890 check_device(f_name, key, query.device) 1891 check_device(f_name, value, query.device) 1892 check_device(f_name, attn_mask, query.device) 1893 1894 check_dtype(f_name, key, query.dtype) 1895 check_dtype(f_name, value, query.dtype) 1896 if attn_mask.dtype is not torch.bool: 1897 check_dtype(f_name, attn_mask, query.dtype) 1898 1899 sdpa = sampled_addmm( 1900 attn_mask, query, key.transpose(-2, -1), beta=0.0, skip_checks=False 1901 ) 1902 if scale is None and query.size(-1) == 0 or scale == 0.0: 1903 check( 1904 False, 1905 f"{f_name}(): current value of scale == {scale} " 1906 "results in division by zero.", 1907 ) 1908 scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale 1909 sdpa.values().mul_(scale_factor) 1910 sdpa = bsr_softmax(sdpa) 1911 torch.nn.functional.dropout(sdpa.values(), p=dropout_p, inplace=True) 1912 sdpa = bsr_dense_mm(sdpa, value) 1913 return sdpa 1914 1915 @triton.jit 1916 def _scatter_mm2_kernel( 1917 M: tl.constexpr, 1918 K: tl.constexpr, 1919 N: tl.constexpr, 1920 blocks_ptr, 1921 blocks_stride_P, 1922 blocks_stride_M, 1923 blocks_stride_K, 1924 others_ptr, 1925 others_stride_Q, 1926 others_stride_K, 1927 others_stride_N, 1928 accumulators_ptr, 1929 accumulators_stride_R, 1930 accumulators_stride_M, 1931 accumulators_stride_N, 1932 pq_offsets_ptr, 1933 pq_offsets_stride, 1934 pq_ptr, 1935 pq_stride_T, 1936 pq_stride_1, 1937 dot_out_dtype: tl.constexpr, 1938 TILE_M: tl.constexpr, 1939 TILE_N: tl.constexpr, 1940 allow_tf32: tl.constexpr, 1941 ): 1942 Ms = M // TILE_M 1943 Ns = N // TILE_N 1944 1945 pid_t = tl.program_id(axis=0) 1946 1947 pid = tl.program_id(axis=1) 1948 pid_m = pid // Ms 1949 pid_n 
    @triton.jit
    def _scatter_mm2_kernel(
        M: tl.constexpr,
        K: tl.constexpr,
        N: tl.constexpr,
        blocks_ptr,
        blocks_stride_P,
        blocks_stride_M,
        blocks_stride_K,
        others_ptr,
        others_stride_Q,
        others_stride_K,
        others_stride_N,
        accumulators_ptr,
        accumulators_stride_R,
        accumulators_stride_M,
        accumulators_stride_N,
        pq_offsets_ptr,
        pq_offsets_stride,
        pq_ptr,
        pq_stride_T,
        pq_stride_1,
        dot_out_dtype: tl.constexpr,
        TILE_M: tl.constexpr,
        TILE_N: tl.constexpr,
        allow_tf32: tl.constexpr,
    ):
        Ms = M // TILE_M
        Ns = N // TILE_N

        pid_t = tl.program_id(axis=0)

        pid = tl.program_id(axis=1)
        pid_m = pid // Ms
        pid_n = pid % Ms

        rm = pid_m * TILE_M + tl.arange(0, TILE_M)
        rn = pid_n * TILE_N + tl.arange(0, TILE_N)
        rk = tl.arange(0, K)

        A_ptr = blocks_ptr + (
            rm[:, None] * blocks_stride_M + rk[None, :] * blocks_stride_K
        )
        B_ptr = others_ptr + (
            rk[:, None] * others_stride_K + rn[None, :] * others_stride_N
        )

        g0 = tl.load(pq_offsets_ptr + pid_t * pq_offsets_stride)
        g1 = tl.load(pq_offsets_ptr + (pid_t + 1) * pq_offsets_stride)

        if g0 == g1:
            return

        acc_block = tl.zeros((TILE_M, TILE_N), dtype=dot_out_dtype)

        for i in range(g0, g1):
            p = tl.load(pq_ptr + i * pq_stride_T)
            q = tl.load(pq_ptr + i * pq_stride_T + pq_stride_1)
            A = tl.load(A_ptr + p * blocks_stride_P)
            B = tl.load(B_ptr + q * others_stride_Q)
            acc_block += tl.dot(A, B, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)

        C_ptr = (
            accumulators_ptr
            + pid_t * accumulators_stride_R
            + (
                rm[:, None] * accumulators_stride_M
                + rn[None, :] * accumulators_stride_N
            )
        )
        tl.store(C_ptr, acc_block.to(accumulators_ptr.dtype.element_ty))

    def _scatter_mm2(
        blocks: torch.Tensor,
        others: torch.Tensor,
        pq_offsets: torch.Tensor,
        pq_indices: torch.Tensor,
        accumulators: torch.Tensor,
    ):
        P, M, K = blocks.shape
        Q, _, N = others.shape
        R, _, _ = accumulators.shape

        meta = dict(
            TILE_M=max(16, M // 4), TILE_N=max(16, N // 4), num_stages=1, num_warps=2
        )

        def grid(META):
            return (
                pq_offsets.shape[0] - 1,
                triton.cdiv(M, META["TILE_M"]) * triton.cdiv(N, META["TILE_N"]),
                1,
            )

        dot_out_dtype = {
            torch.float16: tl.float32,
            torch.bfloat16: tl.float32,
            torch.float32: tl.float64,
            torch.float64: tl.float64,
        }[accumulators.dtype]
        if "allow_tf32" not in meta:
            meta.update(allow_tf32=dot_out_dtype == tl.float32)
        _scatter_mm2_kernel[grid](
            M,
            K,
            N,
            blocks,
            blocks.stride(0),
            blocks.stride(1),
            blocks.stride(2),
            others,
            others.stride(0),
            others.stride(1),
            others.stride(2),
            accumulators,
            accumulators.stride(0),
            accumulators.stride(1),
            accumulators.stride(2),
            pq_offsets,
            pq_offsets.stride(0),
            pq_indices,
            pq_indices.stride(0),
            pq_indices.stride(1),
            dot_out_dtype=dot_out_dtype,
            **meta,
        )
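
    # What _scatter_mm2 above computes, as a descriptive (non-normative)
    # sketch: pq_offsets is a CSR-like offset vector into the (p, q) pair list
    # stored in pq_indices, and the kernel overwrites output block t with the
    # sum of the selected block products; tiles whose pair list is empty are
    # left untouched.  A slow PyTorch reference under these assumptions:
    #
    #   for t in range(pq_offsets.numel() - 1):
    #       g0, g1 = pq_offsets[t], pq_offsets[t + 1]
    #       if g0 < g1:
    #           accumulators[t] = sum(
    #               blocks[p] @ others[q] for p, q in pq_indices[g0:g1]
    #           )
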
    @triton.jit
    def _scatter_mm6_kernel(
        nbatches,
        Ms,
        Ks: tl.constexpr,
        N,
        blocks_ptr,
        blocks_stride_P,
        blocks_stride_M,
        blocks_stride_K,
        others_ptr,
        others_stride_B,
        others_stride_K,
        others_stride_N,
        accumulators_ptr,
        accumulators_stride_B,
        accumulators_stride_M,
        accumulators_stride_N,
        c_indices_ptr,
        r_offsets_ptr,
        p_offsets_ptr,
        q_offsets_ptr,
        is_compressed: tl.constexpr,
        dot_out_dtype: tl.constexpr,
        SPLIT_N: tl.constexpr,
        TILE_M: tl.constexpr,
        TILE_N: tl.constexpr,
        GROUP_SIZE: tl.constexpr,
        allow_tf32: tl.constexpr,
    ):
        Ns = N // SPLIT_N
        BLOCKS_M = Ms // TILE_M
        BLOCKS_N = Ns // TILE_N

        pid_t_ = tl.program_id(axis=0)
        pid = tl.program_id(axis=1)
        pid_b = pid_t_ % nbatches
        pid_t = pid_t_ // nbatches

        num_pid_in_group = GROUP_SIZE * BLOCKS_N
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE
        group_size_m = min(BLOCKS_M - first_pid_m, GROUP_SIZE)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m

        rm = pid_m * TILE_M + tl.arange(0, TILE_M)
        rn = pid_n * TILE_N + tl.arange(0, TILE_N)
        rk = tl.arange(0, Ks)
        A_ptr = blocks_ptr + (
            rm[:, None] * blocks_stride_M + rk[None, :] * blocks_stride_K
        )
        B_ptr = (
            others_ptr
            + pid_b * others_stride_B
            + (rk[:, None] * others_stride_K + rn[None, :] * others_stride_N)
        )

        # When is_compressed is True, r is the only variable that
        # depends on pid_t. This property allows sorting r values
        # before calling the kernel. The sorting of r is equivalent to
        # defining a swizzle operator outside of the kernel.
        r = tl.load(r_offsets_ptr + pid_t)

        if is_compressed:
            m = (r // N) // Ms
            n = (r % N) // Ns
            r0 = tl.load(c_indices_ptr + m)
            r1 = tl.load(c_indices_ptr + m + 1)
            # Note that n * r1 + (SPLIT_N - n) * r0 == SPLIT_N * r0 + n * (r1 - r0),
            # i.e. g0 skips the q_offsets entries of all preceding block rows
            # (SPLIT_N segments of r0 entries each) plus the n preceding
            # column-split segments of the current block row.
            g0 = n * r1 + (SPLIT_N - n) * r0
            nnz = r1 - r0
        else:
            g0 = tl.load(c_indices_ptr + pid_t)
            g1 = tl.load(c_indices_ptr + pid_t + 1)
            nnz = g1 - g0

        q_ptr = q_offsets_ptr + g0
        acc_block = tl.zeros((TILE_M, TILE_N), dtype=dot_out_dtype)

        if is_compressed:
            A_ptr += r0 * blocks_stride_P  # type: ignore[possibly-undefined]
            for _ in range(nnz):
                q = tl.load(q_ptr)
                B = tl.load(B_ptr + q)
                A = tl.load(A_ptr)
                acc_block += tl.dot(
                    A, B, out_dtype=dot_out_dtype, allow_tf32=allow_tf32
                )
                A_ptr += blocks_stride_P
                q_ptr += 1
        else:
            p_ptr = p_offsets_ptr + g0
            for _ in range(nnz):
                q = tl.load(q_ptr)
                B = tl.load(B_ptr + q)
                p = tl.load(p_ptr)
                A = tl.load(A_ptr + p * blocks_stride_P)
                p_ptr += 1
                q_ptr += 1
                acc_block += tl.dot(
                    A, B, out_dtype=dot_out_dtype, allow_tf32=allow_tf32
                )

        C_ptr = (
            accumulators_ptr
            + r
            + pid_b * accumulators_stride_B
            + (
                rm[:, None] * accumulators_stride_M
                + rn[None, :] * accumulators_stride_N
            )
        )
        tl.store(C_ptr, acc_block.to(accumulators_ptr.dtype.element_ty))
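
    # A descriptive note on the index tensors consumed by the kernel above
    # (an inferred summary of the kernel's addressing, not a formal spec):
    #
    # - r_offsets[t] is the flat element offset of the top-left element of
    #   output tile t within one (M, N) slice of ``accumulators``.
    # - When is_compressed is True, c_indices plays the role of compressed
    #   block-row offsets (crow_indices-like): blocks
    #   c_indices[m]:c_indices[m + 1] are consumed sequentially, each paired
    #   with the ``others`` tile addressed by the corresponding q_offsets
    #   entry; p_offsets is not used in this mode.
    # - When is_compressed is False, c_indices[t]:c_indices[t + 1] bounds a
    #   segment of (p_offsets, q_offsets) pairs, where p selects a block of
    #   ``blocks`` and q is an element offset into ``others``.
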
    def _scatter_mm6(
        blocks: torch.Tensor,
        others: torch.Tensor,
        c_indices: torch.Tensor,
        r_offsets: torch.Tensor,
        p_offsets: torch.Tensor,
        q_offsets: torch.Tensor,
        meta: dict,
        accumulators: torch.Tensor,
        force_contiguous: bool = True,
    ):
        SPLIT_N = meta["SPLIT_N"]
        P, Ms, Ks = blocks.shape
        B, K_, N = others.shape
        B_, M, N_ = accumulators.shape
        assert N_ == N
        Ns = N // SPLIT_N
        assert B_ == B

        def grid(META):
            return (
                r_offsets.shape[0] * B,
                triton.cdiv(Ms, META["TILE_M"]) * triton.cdiv(Ns, META["TILE_N"]),
            )

        dot_out_dtype = {
            torch.float16: tl.float32,
            torch.bfloat16: tl.float32,
            torch.float32: tl.float64,
            torch.float64: tl.float64,
        }[accumulators.dtype]
        if "allow_tf32" not in meta:
            meta.update(allow_tf32=dot_out_dtype == tl.float32)

        assert c_indices.stride(0) == 1
        assert r_offsets.stride(0) == 1
        assert p_offsets.stride(0) == 1
        assert q_offsets.stride(0) == 1

        # Regarding non-contiguous tensor arguments: Triton kernel launches
        # sometimes fail with
        #
        #   RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered
        #
        # which appears to be the case when the size of a non-contiguous
        # tensor argument exceeds a certain threshold. Could this be related
        # to the shared memory or L1 cache size of the GPU? In any case,
        # ensuring that tensor arguments are contiguous seems to avoid the
        # exception, so below we always convert tensor arguments to
        # C-contiguous tensors.

        if force_contiguous:
            blocks = blocks.contiguous()
            others = others.contiguous()
            if not accumulators.is_contiguous():
                accumulators_ = accumulators.contiguous()
            else:
                accumulators_ = accumulators
        else:
            accumulators_ = accumulators

        _scatter_mm6_kernel[grid](
            B,
            Ms,
            Ks,
            N,
            blocks,
            blocks.stride(0),
            blocks.stride(1),
            blocks.stride(2),
            others,
            others.stride(0),
            others.stride(1),
            others.stride(2),
            accumulators_,
            accumulators_.stride(0),
            accumulators_.stride(1),
            accumulators_.stride(2),
            c_indices,
            r_offsets,
            p_offsets,
            q_offsets,
            dot_out_dtype=dot_out_dtype,
            **meta,
        )

        if force_contiguous and not accumulators.is_contiguous():
            accumulators.copy_(accumulators_)
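
    # A minimal call sketch for the _scatter_mm6 wrapper above; shapes and
    # meta values are made up for illustration (in practice meta comes from
    # tuned parameters such as those returned by get_meta, and the index
    # tensors are produced by the BSR scatter-mm helpers in this module):
    #
    #   blocks       : (P, Ms, Ks)  stacked value blocks
    #   others       : (B, K, N)    batched dense operand
    #   accumulators : (B, M, N)    preallocated (typically zero-filled) output
    #   meta         : dict(SPLIT_N=..., TILE_M=..., TILE_N=..., GROUP_SIZE=...,
    #                       num_stages=..., num_warps=...)
    #
    #   _scatter_mm6(blocks, others, c_indices, r_offsets, p_offsets,
    #                q_offsets, meta, accumulators)
    #
    # The tiling arithmetic assumes N is divisible by SPLIT_N, Ms by TILE_M,
    # and N // SPLIT_N by TILE_N; the kernel performs no bounds masking.
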
    @triton.jit
    def _bsr_strided_addmm_kernel(
        # values prologue
        values_ptr,
        values_batch_stride,
        values_nnz_stride,
        values_row_block_stride,
        values_col_block_stride,
        # values epilogue
        # crow_indices prologue
        crow_indices_ptr,
        crow_indices_batch_stride,
        crow_indices_stride,
        # crow_indices epilogue
        # col_indices prologue
        col_indices_ptr,
        col_indices_batch_stride,
        col_indices_stride,
        # col_indices epilogue
        # input prologue
        input_ptr,
        input_batch_stride,
        input_tiled_row_stride,
        input_tiled_col_stride,
        input_row_block_stride,
        input_col_block_stride,
        # input epilogue
        # dense prologue
        dense_ptr,
        dense_batch_stride,
        dense_tiled_row_stride,
        dense_tiled_col_stride,
        dense_row_block_stride,
        dense_col_block_stride,
        # dense epilogue
        # output prologue
        output_ptr,
        output_batch_stride,
        output_tiled_row_stride,
        output_tiled_col_stride,
        output_row_block_stride,
        output_col_block_stride,
        # output epilogue
        beta,
        alpha,
        beta_is_one: tl.constexpr,
        beta_is_nonzero: tl.constexpr,
        alpha_is_one: tl.constexpr,
        BLOCKSIZE_ROW: tl.constexpr,
        BLOCKSIZE_COL: tl.constexpr,
        BLOCKSIZE_INNER: tl.constexpr,
        acc_dtype: tl.constexpr,
        allow_tf32: tl.constexpr,
        GROUP_SIZE_ROW: tl.constexpr,
        SPLIT_N: tl.constexpr,
    ):
        batch_pid = tl.program_id(axis=2)
        row_block_pid = tl.program_id(axis=0)
        col_block_pid = tl.program_id(axis=1)
        n_block_rows = tl.num_programs(axis=0)
        n_block_cols = tl.num_programs(axis=1)

        row_block_pid, col_block_pid = tl.swizzle2d(
            row_block_pid, col_block_pid, n_block_rows, n_block_cols, GROUP_SIZE_ROW
        )

        crow_indices_offset_ptr = (
            crow_indices_ptr
            + crow_indices_batch_stride * batch_pid
            + crow_indices_stride * row_block_pid
        )
        nnz_offset = tl.load(crow_indices_offset_ptr)
        nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)

        # Compute nnz for the row with number row_block_pid.
        row_nnz = nnz_offset_next - nnz_offset

        row_block_arange = tl.arange(0, BLOCKSIZE_ROW)
        inner_block_arange = tl.arange(0, BLOCKSIZE_INNER)
        col_block_arange = tl.arange(0, BLOCKSIZE_COL)

        if beta_is_nonzero:
            # Pointers are set to exact write-to locations
            input_ptrs = (
                input_ptr
                + input_batch_stride * batch_pid
                + input_tiled_row_stride * row_block_pid
                + input_tiled_col_stride * col_block_pid
                + input_row_block_stride * row_block_arange[:, None]
                + input_col_block_stride * col_block_arange[None, :]
            )

        # Pointers are set to the first block of the current row.
        values_block_ptrs = (
            values_ptr
            + values_batch_stride * batch_pid
            + values_nnz_stride * nnz_offset
            + values_row_block_stride * row_block_arange[:, None]
            + values_col_block_stride * inner_block_arange[None, :]
        )

        # NOTE: dense is advanced into all dimensions but the tiled row one.
        # That will be advanced in the loop according to values in col_indices.
        dense_block_ptrs = (
            dense_ptr
            + dense_batch_stride * batch_pid
            + dense_tiled_col_stride * col_block_pid
            + dense_row_block_stride * inner_block_arange[:, None]
            + dense_col_block_stride * col_block_arange[None, :]
        )

        # Pointers are set to exact write-to locations
        output_ptrs = (
            output_ptr
            + output_batch_stride * batch_pid
            + output_tiled_row_stride * row_block_pid
            + output_tiled_col_stride * col_block_pid
            + output_row_block_stride * row_block_arange[:, None]
            + output_col_block_stride * col_block_arange[None, :]
        )

        # Set pointer to the first nonzero element in the current row
        col_index_nnz_ptr = (
            col_indices_ptr
            + col_indices_batch_stride * batch_pid
            + col_indices_stride * nnz_offset
        )

        # alpha is never 0
        if beta_is_nonzero:
            output_acc_block = tl.load(input_ptrs).to(acc_dtype)  # type: ignore[possibly-undefined]
            if not (beta_is_one and alpha_is_one):
                # Pre-scale by beta / alpha so that the single multiplication
                # by alpha after the accumulation loop yields
                # beta * input + alpha * (bsr @ dense).
                beta_alpha = beta / alpha
                output_acc_block *= beta_alpha
        else:
            output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype)

        for _ in range(row_nnz):
            values_block = tl.load(values_block_ptrs)

            # find which row of dense needs to get loaded
            # for multiplication with values_block.
            dense_row_idx = tl.load(col_index_nnz_ptr)
            dense_block = tl.load(
                dense_block_ptrs + dense_tiled_row_stride * dense_row_idx
            )

            # do block mm
            output_acc_block += tl.dot(
                values_block, dense_block, allow_tf32=allow_tf32, out_dtype=acc_dtype
            )

            # move val/col_index ptrs to the next block in the row
            values_block_ptrs += values_nnz_stride
            col_index_nnz_ptr += col_indices_stride

        if not alpha_is_one:
            output_acc_block *= alpha

        # write back the result
        tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty))

else:
    bsr_softmax = None  # type: ignore[assignment]
    bsr_dense_mm = None  # type: ignore[assignment]
    sampled_addmm = None  # type: ignore[assignment]
    _scaled_dot_product_attention = None  # type: ignore[assignment]
    _scatter_mm2 = None  # type: ignore[assignment]
    _scatter_mm6 = None  # type: ignore[assignment]
    _bsr_strided_addmm_kernel = None  # type: ignore[assignment]
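
# When Triton is unavailable, the names above are bound to None, so callers
# are expected to guard their use.  An illustrative sketch (bsr_tensor and
# dense_tensor are placeholder names):
#
#   if bsr_dense_mm is not None:
#       result = bsr_dense_mm(bsr_tensor, dense_tensor)
#   else:
#       result = bsr_tensor.to_dense() @ dense_tensor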