xref: /aosp_15_r20/external/pytorch/torch/sparse/_triton_ops.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-decorators
2# mypy: allow-untyped-defs
3import math
4import os
5import weakref
6from functools import lru_cache
7from typing import Optional, Tuple
8
9import torch
10from torch._dynamo.utils import warn_once
11from torch.utils._triton import has_triton
12
13from ._triton_ops_meta import get_meta
14
15
16TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE = int(
17    os.getenv("TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE", 2)
18)
19
20
21def check(cond, msg):
22    if not cond:
23        raise ValueError(msg)
24
25
26def check_bsr_layout(f_name, t):
27    check(
28        t.layout == torch.sparse_bsr,
29        f"{f_name}(): only BSR sparse format is supported for the sparse argument.",
30    )
31
32
33def check_device(f_name, t, device):
34    check(
35        t.device == device and t.device.type == "cuda",
36        f"{f_name}(): all inputs are expected to be on the same GPU device.",
37    )
38
39
40def check_mm_compatible_shapes(f_name, lhs, rhs):
41    check(
42        lhs.dim() >= 2 and rhs.dim() >= 2,
43        f"{f_name}(): all inputs involved in the matrix product are expected to be at least 2D, "
44        f"but got lhs.dim() == {lhs.dim()} and rhs.dim() == {rhs.dim()}.",
45    )
46
47    m, kl = lhs.shape[-2:]
48    kr, n = rhs.shape[-2:]
49
50    check(
51        kl == kr,
52        f"{f_name}(): arguments' sizes involved in the matrix product are not compatible for matrix multiplication, "
53        f"got lhs.shape[-1] == {kl} which is not equal to rhs.shape[-2] == {kr}.",
54    )
55
56
57def check_dtype(f_name, t, dtype, *additional_dtypes):
58    check(
59        t.dtype == dtype
60        and t.dtype
61        in ((torch.half, torch.bfloat16, torch.float) + tuple(*additional_dtypes)),
62        f"{f_name}(): all inputs are expected to be of the same dtype "
63        f"and one of (half, bfloat16, float32) or {additional_dtypes}, "
64        f"but got dtype == {t.dtype}.",
65    )
66
67
68def check_blocksize(f_name, blocksize):
69    assert len(blocksize) == 2
70
71    def is_power_of_two(v):
72        return not (v & (v - 1))
73
74    def is_compatible_blocksize(b):
75        res = True
76        for blocksize in b:
77            # Triton loads only blocks which are at least 16 and powers of 2.
78            res = (blocksize >= 16 and is_power_of_two(blocksize)) and res
79        return res
80
81    check(
82        is_compatible_blocksize(blocksize),
83        f"{f_name}(): sparse inputs' blocksize ({blocksize[0]}, {blocksize[1]}) "
84        "should be at least 16 and a power of 2 in each dimension.",
85    )
86
87
88def make_triton_contiguous(t):
89    """Return input as a triton-contiguous tensor.
90
91    A triton-contiguous tensor is defined as a tensor that has strides
92    with minimal value equal to 1.
93
94    While triton kernels support triton-non-contiguous tensors (all
95    strides being greater than 1 or having 0 strides) arguments, a
96    considerable slow-down occurs because tensor data is copied
97    element-wise rather than chunk-wise.
98    """
99    if min(t.stride()) != 1:
100        # TODO: investigate if contiguity along other axes than the
101        # last one can be beneficial for performance
102        return t.contiguous()
103    else:
104        return t
105
106
107def broadcast_batch_dims(f_name, *tensors):
108    try:
109        return torch.broadcast_shapes(*(t.shape[:-2] for t in tensors))
110    except Exception:
111        check(False, f"{f_name}(): inputs' batch dimensions are not broadcastable!")
112
113
114def slicer(dim, slice_range, *tensors):
115    for t in tensors:
116        slices = [slice(None)] * t.dim()
117        slices[dim] = slice_range
118        yield t[slices]
119
120
121def multidim_slicer(dims, slices, *tensors):
122    for t in tensors:
123        s = [slice(None)] * t.dim()
124        for d, d_slice in zip(dims, slices):
125            if d is not None:
126                s[d] = d_slice
127        yield t[s]
128
129
130def ptr_stride_extractor(*tensors):
131    for t in tensors:
132        yield t
133        yield from t.stride()
134
135
136def grid_partitioner(full_grid, grid_blocks, tensor_dims_map):
137    assert 0 <= len(full_grid) <= 3
138    assert 0 <= len(grid_blocks) <= 3
139
140    import itertools
141
142    def generate_grid_points():
143        for fg, mg in zip(full_grid, grid_blocks):
144            yield range(0, fg, mg)
145
146    def generate_sliced_tensors(slices):
147        for t, t_dims in tensor_dims_map.items():
148            yield next(multidim_slicer(t_dims, slices, t))
149
150    for grid_point in itertools.product(*generate_grid_points()):
151        grid = [
152            min(fg - gp, mg) for fg, gp, mg in zip(full_grid, grid_point, grid_blocks)
153        ]
154        slices = [slice(gp, gp + g) for gp, g in zip(grid_point, grid)]
155        # grid_points are iterated in a "contiguous" order, i.e.
156        # left dimensions traversed slower than right dimensions.
157        # This order is reversed for CUDA grids.
158        yield grid[::-1], *generate_sliced_tensors(slices)
159
160
161def launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks=None):
162    # cuda_max_grid = (2 ** 31 - 1, 2 ** 16 - 1, 2 ** 16 - 1)
163    cuda_max_grid = (2147483647, 65535, 65535)[::-1]
164    if grid_blocks is None:
165        grid_blocks = cuda_max_grid
166    else:
167
168        def valid_grid_dim(g, mg):
169            if g is None:
170                return mg
171            else:
172                # grid must be at least 1 and no greater than mg
173                return max(1, min(g, mg))
174
175        grid_blocks = tuple(
176            valid_grid_dim(g, mg) for g, mg in zip(grid_blocks, cuda_max_grid)
177        )  # type: ignore[assignment]
178
179    for grid, *sliced_tensors in grid_partitioner(
180        full_grid, grid_blocks, tensor_dims_map
181    ):
182        kernel(grid, *sliced_tensors)
183
184
185def prepare_inputs(bsr, *dense_tensors):
186    # Introduce fake batch dimension if not present for convenience.
187    crow_indices = bsr.crow_indices().unsqueeze(0)
188    col_indices = bsr.col_indices().unsqueeze(0)
189    values = make_triton_contiguous(bsr.values().unsqueeze(0))
190    tensors = [make_triton_contiguous(t.unsqueeze(0)) for t in dense_tensors]
191
192    # Compute broadcasted batch dimension
193    batch_dims_broadcasted = torch.broadcast_shapes(
194        values.shape[:-3], *(t.shape[:-2] for t in tensors)
195    )
196
197    # Broadcast batch dimensions and squash.
198    # The result can be either a view or a copy.
199    def batch_broadcast_and_squash(t, batch_dims, invariant_dims):
200        return t.broadcast_to(batch_dims + invariant_dims).flatten(
201            0, len(batch_dims) - 1
202        )
203
204    crow_indices = batch_broadcast_and_squash(
205        crow_indices, batch_dims_broadcasted, (-1,)
206    )
207
208    col_indices = batch_broadcast_and_squash(col_indices, batch_dims_broadcasted, (-1,))
209    values = batch_broadcast_and_squash(
210        values, batch_dims_broadcasted, values.shape[-3:]
211    )
212    tensors = [
213        batch_broadcast_and_squash(t, batch_dims_broadcasted, t.shape[-2:])
214        for t in tensors
215    ]
216
217    return crow_indices, col_indices, values, *tensors
218
219
220def broadcast_batch_dims_bsr(f_name, bsr, *tensors):
221    batch_shape = broadcast_batch_dims(f_name, bsr, *tensors)
222
223    crow_indices = bsr.crow_indices().broadcast_to(batch_shape + (-1,))
224    col_indices = bsr.col_indices().broadcast_to(batch_shape + (-1,))
225    values = bsr.values().broadcast_to(batch_shape + bsr.values().shape[-3:])
226    size = batch_shape + bsr.shape[-2:]
227    return torch.sparse_compressed_tensor(
228        crow_indices, col_indices, values, size=size, layout=bsr.layout
229    )
230
231
232# NOTE: this function will ALWAYS create a view
233def tile_to_blocksize(t, blocksize):
234    *rest, m, n = t.shape
235    new_shape = rest + [
236        m // blocksize[0],
237        blocksize[0],
238        n // blocksize[1],
239        blocksize[1],
240    ]
241    # using .view instead of .reshape to ensure that the result is
242    # indeed a view:
243    return t.view(new_shape).transpose(-3, -2)
244
245
246def as1Dbatch(tensor):
247    """Return tensor as 3D tensor by either prepending new dimensions to
248    the tensor shape (when ``tensor.ndim < 3``), or by collapsing
249    starting dimensions into the first dimension (when ``tensor.ndim >
250    3``).
251    """
252    while tensor.ndim < 3:
253        tensor = tensor.unsqueeze(0)
254    if tensor.ndim > 3:
255        tensor = tensor.flatten(0, tensor.ndim - 3)
256    assert tensor.ndim == 3, tensor.shape
257    return tensor
258
259
260def scatter_mm(blocks, others, indices_data, *, accumulators=None):
261    """Scattered matrix multiplication of tensors.
262
263    A scattered matrix multiplication is defined as a series of matrix
264    multiplications applied to input tensors according to the input
265    and output mappings specified by indices data.
266
267    The following indices data formats are supported for defining a
268    scattered matrix multiplication operation (:attr:`indices_data[0]`
269    holds the name of the indices data format as specified below):
270
271    - ``"scatter_mm"`` - matrix multiplications scattered in batches
272      of tensors.
273
274      If :attr:`blocks` is a :math:`(* \times M \times K) tensor,
275      :attr:`others` is a :math:`(* \times K \times N)` tensor,
276      :attr:`accumulators` is a :math:`(* \times M \times N)` tensor,
277      and :attr:`indices = indices_data['indices']` is a :math:`(*
278      \times 3)` tensor, then the operation is equivalent to the
279      following code::
280
281        c_offsets, pq = indices_data[1:]
282        for r in range(len(c_offsets) - 1):
283            for g in range(c_offsets[r], c_offsets[r + 1]):
284                p, q = pq[g]
285                accumulators[r] += blocks[p] @ others[q]
286
287    - ``"bsr_strided_mm"`` - matrix multiplications scattered in
288      batches of tensors and a tensor.
289
290      If :attr:`blocks` is a :math:`(Ms \times Ks) tensor,
291      :attr:`others` is a :math:`(* \times K \times N)` tensor,
292      :attr:`accumulators` is a :math:`(* \times M \times N)` tensor, then
293      the operation is equivalent to the following code::
294
295        c_indices, r_offsets, p_offsets, q_offsets, meta = indices_data[1:]
296        for b in range(nbatches):
297            for i, r in enumerate(r_offsets):
298                r0, r1 = divmod(r, N)
299                acc = accumulators[b, r0:r0 + Ms, r1:r1 + Ns]
300                for g in range(c_indices[i], c_indices[i+1]):
301                    p = p_offsets[g]
302                    q0, q1 = divmod(q_offsets[g], N)
303                    acc += blocks[p] @ others[b, q0:q0 + Ks, q1:q1 + Ns]
304
305      where ``Ns = N // meta['SPLIT_N']``, and ``M`` and ``K`` are
306      integer multiples of ``Ms`` and ``Ks``, respectively.
307
308    - ``"bsr_strided_mm_compressed"`` - matrix multiplications
309      scattered in batches of tensors and a tensor. A memory and
310      processor efficient version of ``"bsr_strided_mm"`` format.  If
311      :attr:`blocks` is a :math:`(Ms \times Ks) tensor, :attr:`others`
312      is a :math:`(* \times K \times N)` tensor, :attr:`accumulators`
313      is a :math:`(* \times M \times N)` tensor, then the operation is
314      equivalent to the following code::
315
316        c_indices, r_offsets, q_offsets, meta = indices_data[1:]
317        for b in range(nbatches):
318            for r in r_offsets:
319                m = (r // N) // Ms
320                n = (r % N) // Ns
321                r0, r1 = divmod(r, N)
322                c0, c1 = c_indices[m], c_indices[m + 1]
323                acc = accumulators[b, r0:r0 + Ms, r1:r1 + Ns]
324                for i, p in enumerate(range(c0, c1)):
325                    q = q_offsets[n * c1 + (SPLIT_N - n) * c0 + i]
326                    q0, q1 = divmod(q, N)
327                    acc += blocks[p] @ others[b, q0:q0 + Ks, q1:q1 + Ns]
328
329      where ``Ns = N // meta['SPLIT_N']``, and ``M`` and ``K`` are
330      integer multiples of ``Ms`` and ``Ks``, respectively.
331
332      Notice that the order of ``r_offsets`` items can be arbitrary;
333      this property enables defining swizzle operators via
334      rearrangements of ``r_offsets`` items..
335
336    Auxilary functions are provided for pre-computing
337    :attr:`indices_data`. For example,
338    :func:`bsr_scatter_mm_indices_data` is used to define indices data
339    for matrix multiplication of BSR and strided tensors.
340
341    Parameters
342    ----------
343    blocks (Tensor): a 3-D tensor of first matrices to be multiplied
344
345    others (Tensor): a tensor of second matrices to be multiplied. If
346      ``indices_data[0]=="scatter_mm"``, the tensor is a 1-D batch
347      tensor of second input matrices to be multiplied. Otherwise, the
348      second input matrices are slices of the :attr:`others` tensor.
349    indices_data (tuple): a format data that defines the inputs and
350      outputs of scattered matrix multiplications.
351
352    Keyword arguments
353    -----------------
354
355    accumulators (Tensor, optional): a tensor of matrix product
356      accumulators. If ``indices_data[0]=="scatter_mm"``, the tensor
357      is a 1-D batch tensor of output matrices. Otherwise, output
358      matrices are slices of the :attr:`accumulators` tensor.
359    """
360    indices_format = indices_data[0]
361
362    assert blocks.ndim == 3
363    P, Ms, Ks = blocks.shape
364
365    if indices_format == "scatter_mm":
366        c_offsets, pq = indices_data[1:]
367
368        assert others.ndim == 3
369        Q, Ks_, Ns = others.shape
370        assert Ks == Ks_
371
372        if accumulators is None:
373            R = c_offsets.shape[0] - 1
374            accumulators = torch.zeros(
375                (R, Ms, Ns), dtype=blocks.dtype, device=blocks.device
376            )
377        else:
378            R, Ms_, Ns_ = accumulators.shape
379            assert Ms_ == Ms
380            assert Ns_ == Ns
381
382        if Ms % 16 or Ks % 16 or Ns % 16 or _scatter_mm2 is None:
383            for r in range(c_offsets.shape[0] - 1):
384                g0 = c_offsets[r]
385                g1 = c_offsets[r + 1]
386                for g in range(g0, g1):
387                    p, q = pq[g]
388                    accumulators[r] += blocks[p] @ others[q]
389        else:
390            _scatter_mm2(blocks, others, c_offsets, pq, accumulators)
391        return accumulators
392
393    elif indices_format == "bsr_strided_mm":
394        others_shape = others.shape
395        others = as1Dbatch(others)
396
397        B, K, N = others.shape
398        assert K % Ks == 0
399
400        c_indices, r_offsets, p_offsets, q_offsets, meta = indices_data[1:]
401        SPLIT_N = meta["SPLIT_N"]
402
403        if accumulators is None:
404            M = Ms + (r_offsets.max().item() + 1) // N
405            accumulators = torch.zeros(
406                (*others_shape[:-2], M, N), dtype=blocks.dtype, device=blocks.device
407            )
408        else:
409            M, N_ = accumulators.shape[-2:]
410            assert N_ == N
411
412        accumulators_shape = accumulators.shape
413        accumulators = as1Dbatch(accumulators)
414
415        Ns = N // SPLIT_N
416
417        if Ms % 16 or Ks % 16 or Ns % 16 or _scatter_mm6 is None:
418            accumulators.zero_()
419            for b in range(B):
420                for r in range(r_offsets.shape[0]):
421                    r_ = r_offsets[r].item()
422                    g0 = c_indices[r].item()
423                    g1 = c_indices[r + 1].item()
424                    r0, r1 = divmod(r_, N)
425                    acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns]
426                    for g in range(g0, g1):
427                        p, q = p_offsets[g], q_offsets[g]
428                        q0, q1 = divmod(q.item(), N)
429                        acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns]
430        else:
431            _scatter_mm6(
432                blocks,
433                others,
434                c_indices,
435                r_offsets,
436                p_offsets,
437                q_offsets,
438                meta,
439                accumulators,
440            )
441        return accumulators.view(accumulators_shape)
442
443    elif indices_format == "bsr_strided_mm_compressed":
444        others_shape = others.shape
445        others = as1Dbatch(others)
446
447        B, K, N = others.shape
448        assert K % Ks == 0
449
450        c_indices, r_offsets, q_offsets, meta = indices_data[1:]
451        SPLIT_N = meta["SPLIT_N"]
452
453        if accumulators is None:
454            M = Ms + (r_offsets.max().item() + 1) // N
455            accumulators = torch.zeros(
456                (*others_shape[:-2], M, N), dtype=blocks.dtype, device=blocks.device
457            )
458        else:
459            M, N_ = accumulators.shape[-2:]
460            assert N_ == N
461
462        accumulators_shape = accumulators.shape
463        accumulators = as1Dbatch(accumulators)
464
465        Ns = N // SPLIT_N
466
467        if Ms % 16 or Ks % 16 or Ns % 16 or _scatter_mm6 is None:
468            for b in range(B):
469                for j in range(len(r_offsets)):
470                    r0, r1 = divmod(r_offsets[j].item(), N)
471                    m = r0 // Ms
472                    n = r1 // Ns
473                    c0 = c_indices[m].item()
474                    c1 = c_indices[m + 1].item()
475                    acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns]
476                    for i, p in enumerate(range(c0, c1)):
477                        q = q_offsets[n * c1 + (SPLIT_N - n) * c0 + i].item()
478                        q0, q1 = divmod(q, N)
479                        acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns]
480        else:
481            p_offsets = torch.empty(
482                (0,), dtype=q_offsets.dtype, device=q_offsets.device
483            )
484            _scatter_mm6(
485                blocks,
486                others,
487                c_indices,
488                r_offsets,
489                p_offsets,
490                q_offsets,
491                meta,
492                accumulators,
493            )
494        return accumulators.view(accumulators_shape)
495
496    else:
497        raise NotImplementedError(indices_format)
498
499
500def scatter_mm_meta(
501    M,
502    K,
503    N,
504    Ms,
505    Ks,
506    GROUP_SIZE=None,
507    TILE_M=None,
508    TILE_N=None,
509    SPLIT_N=None,
510    num_warps=None,
511    num_stages=None,
512    **extra,
513):
514    if {TILE_M, TILE_N, SPLIT_N, num_warps, num_stages, GROUP_SIZE} == {None}:
515        device_name = torch.cuda.get_device_name()
516        meta = get_meta(
517            "scatter_mm",
518            (M, K, N, Ms, Ks),
519            device_name,
520            version=(0, torch.float16, 0.5),
521        )
522        if meta is not None:
523            meta.update(**extra)
524            return meta
525        # The following parameters are optimized for the performance
526        # equilibrium points of bsr-dense and dense-dense matrix
527        # multiplications when using GPU card NVIDIA GeForce RTX 2060
528        # SUPER. For points far from the performance equilibrium
529        # points as well as for other GPU cards, the optimal
530        # parameters are likely different from what specified below.
531        if (M, K, N) == (256,) * 3:
532            if (Ms, Ks) == (16, 16):
533                SPLIT_N = 1
534                TILE_M = 16
535                TILE_N = 16
536                GROUP_SIZE = 4
537                num_stages = 1
538                num_warps = 4  # noqa: E225,E231,E702
539            elif (Ms, Ks) == (32, 32):
540                SPLIT_N = 2
541                TILE_M = 32
542                TILE_N = 16
543                GROUP_SIZE = 4
544                num_stages = 1
545                num_warps = 4  # noqa: E225,E231,E702
546            elif (Ms, Ks) == (64, 64):
547                SPLIT_N = 1
548                TILE_M = 32
549                TILE_N = 32
550                GROUP_SIZE = 4
551                num_stages = 1
552                num_warps = 4  # noqa: E225,E231,E702
553            elif (Ms, Ks) == (128, 128):
554                SPLIT_N = 1
555                TILE_M = 32
556                TILE_N = 32
557                GROUP_SIZE = 2
558                num_stages = 1
559                num_warps = 4  # noqa: E225,E231,E702
560        elif (M, K, N) == (512,) * 3:
561            if (Ms, Ks) == (16, 16):
562                SPLIT_N = 8
563                TILE_M = 16
564                TILE_N = 64
565                GROUP_SIZE = 2
566                num_stages = 1
567                num_warps = 2  # noqa: E225,E231,E702
568            elif (Ms, Ks) == (32, 32):
569                SPLIT_N = 8
570                TILE_M = 32
571                TILE_N = 64
572                GROUP_SIZE = 4
573                num_stages = 1
574                num_warps = 2  # noqa: E225,E231,E702
575            elif (Ms, Ks) == (64, 64):
576                SPLIT_N = 4
577                TILE_M = 32
578                TILE_N = 128
579                GROUP_SIZE = 4
580                num_stages = 1
581                num_warps = 4  # noqa: E225,E231,E702
582            elif (Ms, Ks) == (128, 128):
583                SPLIT_N = 8
584                TILE_M = 64
585                TILE_N = 64
586                GROUP_SIZE = 4
587                num_stages = 1
588                num_warps = 4  # noqa: E225,E231,E702
589        elif (M, K, N) == (1024,) * 3:
590            if (Ms, Ks) == (16, 16):
591                SPLIT_N = 4
592                TILE_M = 16
593                TILE_N = 128
594                GROUP_SIZE = 2
595                num_stages = 1
596                num_warps = 1  # noqa: E225,E231,E702
597            elif (Ms, Ks) == (32, 32):
598                SPLIT_N = 8
599                TILE_M = 32
600                TILE_N = 64
601                GROUP_SIZE = 2
602                num_stages = 1
603                num_warps = 1  # noqa: E225,E231,E702
604            elif (Ms, Ks) == (64, 64):
605                SPLIT_N = 16
606                TILE_M = 64
607                TILE_N = 64
608                GROUP_SIZE = 4
609                num_stages = 1
610                num_warps = 2  # noqa: E225,E231,E702
611            elif (Ms, Ks) == (128, 128):
612                SPLIT_N = 16
613                TILE_M = 64
614                TILE_N = 64
615                GROUP_SIZE = 4
616                num_stages = 1
617                num_warps = 4  # noqa: E225,E231,E702
618            elif (Ms, Ks) == (256, 256):
619                SPLIT_N = 16
620                TILE_M = 64
621                TILE_N = 64
622                GROUP_SIZE = 2
623                num_stages = 1
624                num_warps = 4  # noqa: E225,E231,E702
625        elif (M, K, N) == (2048,) * 3:
626            if (Ms, Ks) == (16, 16):
627                SPLIT_N = 4
628                TILE_M = 16
629                TILE_N = 128
630                GROUP_SIZE = 8
631                num_stages = 1
632                num_warps = 1  # noqa: E225,E231,E702
633            elif (Ms, Ks) == (32, 32):
634                SPLIT_N = 4
635                TILE_M = 32
636                TILE_N = 64
637                GROUP_SIZE = 4
638                num_stages = 1
639                num_warps = 1  # noqa: E225,E231,E702
640            elif (Ms, Ks) == (64, 64):
641                SPLIT_N = 4
642                TILE_M = 64
643                TILE_N = 128
644                GROUP_SIZE = 4
645                num_stages = 1
646                num_warps = 4  # noqa: E225,E231,E702
647            elif (Ms, Ks) == (128, 128):
648                SPLIT_N = 8
649                TILE_M = 64
650                TILE_N = 64
651                GROUP_SIZE = 4
652                num_stages = 1
653                num_warps = 4  # noqa: E225,E231,E702
654            elif (Ms, Ks) == (256, 256):
655                SPLIT_N = 4
656                TILE_M = 64
657                TILE_N = 64
658                GROUP_SIZE = 2
659                num_stages = 1
660                num_warps = 4  # noqa: E225,E231,E702
661        elif (M, K, N) == (4096,) * 3:
662            if (Ms, Ks) == (16, 16):
663                SPLIT_N = 2
664                TILE_M = 16
665                TILE_N = 256
666                GROUP_SIZE = 2
667                num_stages = 1
668                num_warps = 2  # noqa: E225,E231,E702
669            elif (Ms, Ks) == (32, 32):
670                SPLIT_N = 2
671                TILE_M = 32
672                TILE_N = 64
673                GROUP_SIZE = 2
674                num_stages = 1
675                num_warps = 1  # noqa: E225,E231,E702
676            elif (Ms, Ks) == (64, 64):
677                SPLIT_N = 2
678                TILE_M = 64
679                TILE_N = 128
680                GROUP_SIZE = 2
681                num_stages = 1
682                num_warps = 4  # noqa: E225,E231,E702
683
684    if SPLIT_N is None:
685        # Assume NVIDIA GeForce RTX 2060 SUPER:
686        # With the probality of 92% (99.9% when N > 512), the
687        # performance will not be worse more than 2% from the
688        # performance when using an optimal value.  Otherwise, when N
689        # <= 512, using the following heuristics may give upto 15%
690        # lower performance.
691        SPLIT_N = {
692            16: 1,
693            32: 2,
694            64: 4,
695            128: 8,
696            256: 16,
697            512: 8,
698            1024: 16,
699            4096: 32,
700            8192: 64,
701        }.get(N, 16)
702        if Ms >= 512 and N >= 2048:
703            SPLIT_N = 1
704    Ns = N // SPLIT_N
705    if TILE_M is None:
706        TILE_M = min(64 if Ns < 512 else 32, Ms)
707    if TILE_N is None:
708        TILE_N = min(64 if Ns < 512 else 32, Ns)
709    num_stages = num_stages or 1
710    if num_warps is None:
711        if min(M, N) > 1024:
712            num_warps = {16: 1, 32: 1, 64: 2}.get(Ms, 4)
713        elif min(M, N) == 1024:
714            num_warps = {16: 1, 32: 1, 64: 2}.get(Ms, 4)
715        elif min(M, N) == 256:
716            num_warps = {16: 1, 32: 4}.get(Ms, 4)
717        else:
718            num_warps = {16: 1, 32: 2}.get(Ms, 4)
719    GROUP_SIZE = GROUP_SIZE or 4
720
721    assert TILE_M <= Ms, dict(TILE_M=TILE_M, Ms=Ms)
722    assert TILE_N <= Ns, dict(TILE_N=TILE_N, Ns=Ns)
723    assert Ms <= M, dict(M=M, Ms=Ms)
724    assert Ns <= N, dict(N=N, Ns=Ns)
725    assert Ks <= K, dict(K=K, Ks=Ks)
726
727    return dict(
728        TILE_M=TILE_M,
729        TILE_N=TILE_N,
730        GROUP_SIZE=GROUP_SIZE,
731        num_stages=num_stages,
732        num_warps=num_warps,
733        SPLIT_N=SPLIT_N,
734        **extra,
735    )
736
737
738def bsr_dense_addmm_meta(
739    M,
740    K,
741    N,
742    Ms,
743    Ks,
744    beta,
745    alpha,
746    SPLIT_N=None,
747    GROUP_SIZE_ROW=None,
748    num_warps=None,
749    num_stages=None,
750    sparsity=None,
751    dtype=None,
752    _version=0,
753    **extra,
754):
755    # Specifying _version is useful for situations when one wants to
756    # discard existing triton kernel tuning results, say, in testing
757    # bsr_dense_addmm_meta functionality.
758    if dtype is None:
759        dtype = torch.float16
760    if sparsity is None:
761        sparsity = 0.5
762    if {SPLIT_N, num_warps, num_stages, GROUP_SIZE_ROW} == {None}:
763        device_name = torch.cuda.get_device_name()
764        key = (M, K, N, Ms, Ks, beta == 0, beta == 1, alpha == 1)
765        meta = get_meta(
766            "bsr_dense_addmm", key, device_name, version=(_version, dtype, sparsity)
767        )
768        if meta is None and sparsity != 0.5:
769            meta = get_meta(
770                "bsr_dense_addmm", key, device_name, version=(_version, dtype, 0.5)
771            )
772        if meta is None:
773            # find approximate meta such that N % SPLIT_N == 0.
774            matching_meta = get_meta(
775                "bsr_dense_addmm",
776                (*key[:2], "*", *key[3:]),
777                device_name,
778                version=(_version, dtype, 0.5),
779            )
780            for mkey in sorted(matching_meta or {}):
781                meta_ = matching_meta[mkey]
782                n = mkey[2]
783                split_n = meta_["SPLIT_N"]
784                c = n // split_n
785                if N % c == 0 and n <= N:
786                    meta = dict(meta_)
787                    meta["SPLIT_N"] = N // c
788        if meta is not None:
789            meta.update(**extra)
790            return meta
791        else:
792            # see [Computing optimal kernel parameters] in
793            # _triton_ops_meta.py for ways to avoid this warning
794            # message
795            warn_once(
796                f"bsr_dense_addmm uses non-optimal triton kernel parameters for {M=} {K=} {N=} {Ms=}, {Ks=} {beta=} {alpha=}"
797            )
798
799    SPLIT_N = SPLIT_N or max(N // Ms, 1)
800    GROUP_SIZE_ROW = GROUP_SIZE_ROW or 4
801    num_stages = num_stages or 1
802    num_warps = num_warps or 4
803    return dict(
804        SPLIT_N=SPLIT_N,
805        GROUP_SIZE_ROW=GROUP_SIZE_ROW,
806        num_stages=num_stages,
807        num_warps=num_warps,
808        **extra,
809    )
810
811
812class TensorAsKey:
813    """A light-weight wrapper of a tensor that enables storing tensors as
814    keys with efficient memory reference based comparision as an
815    approximation to data equality based keys.
816
817    Motivation: the hash value of a torch tensor is tensor instance
818    based that does not use data equality and makes the usage of
819    tensors as keys less useful. For instance, the result of
820    ``len({a.crow_indices(), a.crow_indices()})`` is `2`, although,
821    the tensor results from `crow_indices` method call are equal, in
822    fact, these share the same data storage.
823    On the other hand, for efficient caching of tensors we want to
824    avoid calling torch.equal that compares tensors item-wise.
825
826    TensorAsKey offers a compromise in that it guarantees key equality
827    of tensors that references data in the same storage in the same
828    manner and without accessing underlying data. However, this
829    approach does not always guarantee correctness. For instance, for
830    a complex tensor ``x``, we have ``TensorAsKey(x) ==
831    TensorAsKey(x.conj())`` while ``torch.equal(x, x.conj())`` would
832    return False.
833    """
834
835    def __init__(self, obj):
836        def get_tensor_key(obj):
837            # Warning: TensorAsKey does not track negative nor
838            # conjugate bits of its input object because in the use
839            # case of wrapping compressed/plain indices of compressed
840            # sparse tensors (that are always integer tensors with
841            # non-negative items) these bits are never set. However,
842            # when extending the use of TensorAsKey to float or
843            # complex tensors, the values of these bits (see is_neg
844            # and is_conj methods) must be included in the key as
845            # well.
846            assert not (obj.dtype.is_floating_point or obj.dtype.is_complex), obj.dtype
847            return (
848                obj.data_ptr(),
849                obj.storage_offset(),
850                obj.shape,
851                obj.stride(),
852                obj.dtype,
853            )
854
855        self._obj_ref = weakref.ref(obj)
856        if obj.layout is torch.strided:
857            self.key = get_tensor_key(obj)
858        elif obj.layout in {torch.sparse_csr, torch.sparse_bsr}:
859            self.key = (
860                get_tensor_key(obj.crow_indices()),
861                get_tensor_key(obj.col_indices()),
862            )
863        elif obj.layout in {torch.sparse_csc, torch.sparse_bsc}:
864            self.key = (
865                get_tensor_key(obj.ccol_indices()),
866                get_tensor_key(obj.row_indices()),
867            )
868        else:
869            raise NotImplementedError(obj.layout)
870        self._hash = hash(self.key)
871
872    def __hash__(self):
873        return self._hash
874
875    def __eq__(self, other):
876        if not isinstance(other, TensorAsKey):
877            return False
878        if self.obj is None or other.obj is None:
879            # dead objects always compare unequal unless these are
880            # same objects
881            return self is other
882        return self.key == other.key
883
884    @property
885    def obj(self):
886        """Return object if alive, otherwise None."""
887        return self._obj_ref()
888
889
890@lru_cache(maxsize=TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE)
891def _bsr_scatter_mm_indices_data(
892    indices_format, M, K, N, Ms, Ks, nbatches, SPLIT_N, compressed_sparse_tensor_as_key
893):
894    bsr = compressed_sparse_tensor_as_key.obj
895    assert bsr is not None
896    crow_indices, col_indices = bsr.crow_indices(), bsr.col_indices()
897    device = crow_indices.device
898    indices_dtype = torch.int32
899
900    if indices_format == "bsr_strided_mm_compressed":
901        Ns = N // SPLIT_N
902        q_offsets_lst = []
903        b = torch.arange(SPLIT_N, dtype=indices_dtype, device=device) * Ns
904        for m in range(M // Ms):
905            r0 = crow_indices[m].item()
906            r1 = crow_indices[m + 1].item()
907            if r1 == r0:
908                continue
909            q_offsets_lst.append(
910                (col_indices[r0:r1] * (Ks * N)).repeat(SPLIT_N)
911                + b.repeat_interleave(r1 - r0)
912            )
913        q_offsets = torch.cat(q_offsets_lst)
914        crow_indices_diff = crow_indices.diff()
915        non_zero_row_indices = crow_indices_diff.nonzero()
916        a = non_zero_row_indices * (Ms * N)
917        r_offsets = (a + b).view(-1)
918        c_indices = crow_indices
919        # swizzle operation: mm elements with longer sums are computed first:
920        nnz_per_row = crow_indices_diff[non_zero_row_indices].repeat_interleave(SPLIT_N)
921        nnz_per_row, indices = nnz_per_row.sort(descending=True, stable=True)
922        r_offsets = r_offsets[indices]
923        return (indices_format, c_indices, r_offsets, q_offsets)
924
925    elif indices_format == "bsr_strided_mm":
926        Ns = N // SPLIT_N
927        p_offsets_lst = []
928        q_offsets_lst = []
929        b = torch.arange(SPLIT_N, dtype=indices_dtype, device=device) * Ns
930        for m in range(M // Ms):
931            r0 = crow_indices[m].item()
932            r1 = crow_indices[m + 1].item()
933            if r1 == r0:
934                continue
935            p_offsets_lst.append(
936                torch.arange(r0, r1, dtype=indices_dtype, device=device).repeat(SPLIT_N)
937            )
938            q_offsets_lst.append(
939                (col_indices[r0:r1] * (Ks * N)).repeat(SPLIT_N)
940                + b.repeat_interleave(r1 - r0)
941            )
942        q_offsets = torch.cat(q_offsets_lst)
943        crow_indices_diff = crow_indices.diff()
944        non_zero_row_indices = crow_indices_diff.nonzero()
945        a = non_zero_row_indices * (Ms * N)
946        r_offsets = (a + b).view(-1)
947        c_indices = torch.cat(
948            (
949                crow_indices[:1],
950                torch.cumsum(
951                    crow_indices_diff[non_zero_row_indices].repeat_interleave(SPLIT_N),
952                    0,
953                ),
954            )
955        )
956        p_offsets = torch.cat(p_offsets_lst)
957        return (indices_format, c_indices, r_offsets, p_offsets, q_offsets)
958
959    elif indices_format == "scatter_mm":
960        Ns = Ms
961        c_indices = [0]
962        pq_offsets = []
963        # todo: eliminate inner for-loops for efficiency
964        for b in range(nbatches):
965            for m in range(M // Ms):
966                r0 = crow_indices[m].item()
967                r1 = crow_indices[m + 1].item()
968                for n in range(N // Ns):
969                    c_indices.append(c_indices[-1] + r1 - r0)
970                    for t in range(r1 - r0):
971                        p = r0 + t
972                        q = (col_indices[p].item() + b * (K // Ks)) * (N // Ns) + n
973                        pq_offsets.append([p, q])
974
975        return (
976            indices_format,
977            torch.tensor(c_indices, dtype=indices_dtype, device=device),
978            torch.tensor(pq_offsets, dtype=indices_dtype, device=device),
979        )
980
981    else:
982        raise ValueError(
983            f"Invalid {indices_format=}. Expected bsr_strided_mm_compressed|bsr_strided_mm|scatter_mm"
984        )
985
986
987def bsr_scatter_mm_indices_data(
988    bsr, other, indices_format="bsr_strided_mm_compressed", **meta_input
989):
990    """Computes indices data for :func:`scatter_mm` used in BSR and
991    strided tensor matrix multiplication.
992    """
993    assert bsr.dense_dim() == 0
994    assert bsr.ndim == 2  # no batch dims
995    crow_indices = bsr.crow_indices()
996    col_indices = bsr.col_indices()
997    blocksize = bsr.values().shape[-2:]
998    M, K = bsr.shape
999    Ms, Ks = blocksize
1000    K_, N = other.shape[-2:]
1001    assert K_ == K
1002    nbatches = other.shape[:-2].numel()
1003
1004    meta = scatter_mm_meta(M, K, N, Ms, Ks, **meta_input)
1005    if "allow_tf32" not in meta_input:
1006        meta.update(allow_tf32=bsr.dtype in {torch.float16, torch.bfloat16})
1007    SPLIT_N = meta["SPLIT_N"]
1008    indices_data = _bsr_scatter_mm_indices_data(
1009        indices_format, M, K, N, Ms, Ks, nbatches, SPLIT_N, TensorAsKey(bsr)
1010    )
1011
1012    if indices_format == "bsr_strided_mm_compressed":
1013        meta.update(is_compressed=True)
1014        return indices_data + (meta,)
1015    elif indices_format == "bsr_strided_mm":
1016        meta.update(is_compressed=False)
1017        return indices_data + (meta,)
1018    else:
1019        return indices_data
1020
1021
1022def bsr_scatter_mm(bsr, other, indices_data=None, out=None):
1023    """BSR @ strided -> strided"""
1024
1025    assert bsr.ndim == 2
1026    assert other.ndim >= 2
1027
1028    Ms, Ks, Ns = bsr.shape[-2], bsr.shape[-1], other.shape[-1]
1029    blocksize = bsr.values().shape[-2:]
1030
1031    if indices_data is None:
1032        indices_data = bsr_scatter_mm_indices_data(
1033            bsr, other, indices_format="bsr_strided_mm_compressed"
1034        )
1035
1036    indices_format = indices_data[0]
1037
1038    if out is None:
1039        out = torch.empty(
1040            (*other.shape[:-2], Ms, Ns), dtype=bsr.dtype, device=bsr.device
1041        )
1042    out_shape = out.shape
1043    out = as1Dbatch(out)
1044
1045    if bsr._nnz() == 0:
1046        out.zero_()
1047    elif indices_format in {"bsr_strided_mm_compressed", "bsr_strided_mm"}:
1048        out.zero_()
1049        scatter_mm(bsr.values(), other, indices_data, accumulators=out)
1050    elif indices_format == "scatter_mm":
1051        nbatches = other.shape[:-2].numel()
1052        accumulators = torch.zeros(
1053            (
1054                nbatches * Ms // blocksize[0] * Ns // blocksize[0],
1055                blocksize[0],
1056                blocksize[0],
1057            ),
1058            dtype=bsr.dtype,
1059            device=bsr.device,
1060        )
1061        others = (
1062            as1Dbatch(other)
1063            .transpose(-2, -1)
1064            .view(
1065                nbatches,
1066                Ns // blocksize[0],
1067                blocksize[0],
1068                Ks // blocksize[1],
1069                blocksize[1],
1070            )
1071            .movedim(
1072                (3, 1, 4, 2), (1, 2, 3, 4)
1073            )  # equivalent to .transpose(-3, -2).transpose(-2, -1).transpose(-4, -3)
1074            .flatten(0, 2)
1075        )
1076        scatter_mm(bsr.values(), others, indices_data, accumulators=accumulators)
1077        out.copy_(
1078            accumulators.unflatten(
1079                0, (nbatches, Ms // blocksize[0], Ns // blocksize[0])
1080            )
1081            .movedim(
1082                (1, 2, 3, 4), (3, 1, 4, 2)
1083            )  # equivalent to .transpose(-4, -3).transpose(-2, -1).transpose(-3, -2)
1084            .reshape(nbatches, Ns, Ms)
1085            .transpose(-2, -1)
1086        )
1087    else:
1088        raise NotImplementedError(indices_format)
1089
1090    return out.view(out_shape)
1091
1092
1093def _int_bsr_dense_addmm(
1094    input: torch.Tensor,
1095    bsr: torch.Tensor,
1096    dense: torch.Tensor,
1097    *,
1098    beta=1,
1099    alpha=1,
1100    out: Optional[torch.Tensor] = None,
1101    skip_checks: bool = False,
1102    max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,
1103    meta: Optional[dict] = None,
1104):
1105    if out is None and dense.dtype is torch.int8:
1106        f_name = "_int_bsr_dense_addmm"
1107        crow_indices = bsr.crow_indices()
1108        batch_ndim = crow_indices.dim() - 1
1109        M = bsr.shape[batch_ndim]
1110        N = dense.shape[-1]
1111        original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense)
1112        out = torch.empty(
1113            original_batch_dims_broadcasted + (M, N),
1114            dtype=torch.int32,
1115            device=dense.device,
1116        )
1117    return bsr_dense_addmm(
1118        input,
1119        bsr,
1120        dense,
1121        beta=beta,
1122        alpha=alpha,
1123        out=out,
1124        skip_checks=skip_checks,
1125        max_grid=max_grid,
1126        meta=meta,
1127    )
1128
1129
1130def bsr_dense_addmm(
1131    input: torch.Tensor,
1132    bsr: torch.Tensor,
1133    dense: torch.Tensor,
1134    *,
1135    beta=1,
1136    alpha=1,
1137    out: Optional[torch.Tensor] = None,
1138    skip_checks: bool = False,
1139    max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,
1140    meta: Optional[dict] = None,
1141):
1142    f_name = "bsr_dense_addmm"
1143    values = bsr.values()
1144    crow_indices = bsr.crow_indices()
1145    col_indices = bsr.col_indices()
1146    batch_ndim = crow_indices.dim() - 1
1147    M, K = bsr.shape[batch_ndim : batch_ndim + 2]
1148    blocksize = values.shape[batch_ndim + 1 : batch_ndim + 3]
1149    N = dense.shape[-1]
1150
1151    # todo: implement checks
1152
1153    if out is None:
1154        original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense)
1155        out = dense.new_empty(original_batch_dims_broadcasted + (M, N))
1156
1157    if bsr._nnz() == 0 or alpha == 0 or N == 0 or M == 0 or K == 0:
1158        if beta == 0:
1159            out.zero_()
1160        else:
1161            out.copy_(input)
1162            if beta != 1:
1163                out.mul_(beta)
1164        return out
1165
1166    if meta is None:
1167        sparsity = round(1 - bsr._nnz() * blocksize[0] * blocksize[1] / (M * K), 2)
1168        meta = bsr_dense_addmm_meta(
1169            M,
1170            K,
1171            N,
1172            blocksize[0],
1173            blocksize[1],
1174            beta,
1175            alpha,
1176            sparsity=sparsity,
1177            dtype=out.dtype,
1178        )
1179    out_backup = out
1180
1181    crow_indices, col_indices, values, input, dense, out = prepare_inputs(
1182        bsr, input, dense, out
1183    )
1184
1185    BM, BK = blocksize
1186    SPLIT_N = meta.get("SPLIT_N", N // BM)
1187    BN = N // SPLIT_N
1188
1189    out_untiled = out
1190    out = tile_to_blocksize(out, (BM, BN))
1191    dense = tile_to_blocksize(dense, (BK, BN))
1192    input = tile_to_blocksize(input, (BM, BN))
1193
1194    dot_out_dtype = {
1195        torch.float16: tl.float32,
1196        torch.bfloat16: tl.float32,
1197        torch.float32: tl.float64,
1198        torch.float64: tl.float64,
1199        torch.int8: tl.int32,
1200        torch.int32: tl.int32,
1201    }[out.dtype]
1202
1203    n_batches = dense.size(0)
1204    n_block_rows = crow_indices.size(-1) - 1
1205    n_block_cols = dense.size(-3)
1206
1207    full_grid = (n_batches, n_block_cols, n_block_rows)
1208    if max_grid is not None:
1209        grid_blocks = tuple(max_grid[:3][::-1]) + (None,) * (3 - len(max_grid[:3]))
1210    else:
1211        grid_blocks = None
1212
1213    tensor_dims_map = {
1214        values: (0, None, None),
1215        crow_indices: (0, None, -1),
1216        col_indices: (0, None, None),
1217        input: (0, -3, -4),
1218        dense: (0, -3, None),
1219        out: (0, -3, -4),
1220    }
1221
1222    assert alpha != 0
1223
1224    def kernel(grid, *sliced_tensors):
1225        _bsr_strided_addmm_kernel[grid](
1226            *ptr_stride_extractor(*sliced_tensors),
1227            beta,
1228            alpha,
1229            beta_is_one=beta == 1,
1230            beta_is_nonzero=beta != 0,
1231            alpha_is_one=alpha == 1,
1232            BLOCKSIZE_ROW=BM,
1233            BLOCKSIZE_INNER=BK,
1234            BLOCKSIZE_COL=BN,
1235            allow_tf32=dot_out_dtype == tl.float32,
1236            acc_dtype=dot_out_dtype,
1237            **meta,
1238        )
1239
1240    launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks)
1241
1242    if out.data_ptr() != out_backup.data_ptr():
1243        # prepare_inputs has made a copy of out, copy its content back
1244        # to out_backup:
1245        out_backup.copy_(out_untiled.view(out_backup.shape))
1246
1247    return out_backup
1248
1249
1250if has_triton():
1251    import triton
1252    import triton.language as tl
1253
1254    @triton.jit
1255    def _sampled_addmm_kernel(
1256        alpha,
1257        beta,
1258        IS_BETA_ZERO: tl.constexpr,
1259        BLOCKSIZE_ROW: tl.constexpr,
1260        BLOCKSIZE_COL: tl.constexpr,
1261        k,
1262        TILE_K: tl.constexpr,
1263        values_ptr,
1264        values_batch_stride,
1265        values_nnz_stride,
1266        values_row_block_stride,
1267        values_col_block_stride,
1268        crow_indices_ptr,
1269        crow_indices_batch_stride,
1270        crow_indices_stride,
1271        col_indices_ptr,
1272        col_indices_batch_stride,
1273        col_indices_stride,
1274        mat1_ptr,
1275        mat1_batch_stride,
1276        mat1_tiled_row_stride,
1277        mat1_tiled_col_stride,
1278        mat1_row_block_stride,
1279        mat1_col_block_stride,
1280        mat2_ptr,
1281        mat2_batch_stride,
1282        mat2_tiled_row_stride,
1283        mat2_tiled_col_stride,
1284        mat2_row_block_stride,
1285        mat2_col_block_stride,
1286        acc_dtype: tl.constexpr,
1287        allow_tf32: tl.constexpr,
1288    ):
1289        batch_pid = tl.program_id(axis=1)
1290        row_block_pid = tl.program_id(axis=0)
1291
1292        crow_indices_offset_ptr = (
1293            crow_indices_ptr
1294            + crow_indices_batch_stride * batch_pid
1295            + crow_indices_stride * row_block_pid
1296        )
1297        nnz_offset = tl.load(crow_indices_offset_ptr)
1298        nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)
1299
1300        # Compute nnz for the row with number row_block_pid.
1301        # If it is zero, skip the row.
1302        row_nnz = nnz_offset_next - nnz_offset
1303        if row_nnz == 0:
1304            return
1305
1306        row_block_arange = tl.arange(0, BLOCKSIZE_ROW)
1307        col_block_arange = tl.arange(0, BLOCKSIZE_COL)
1308
1309        # Pointers are set to the first block of the current row.
1310        values_block_ptrs = (
1311            values_ptr
1312            + values_batch_stride * batch_pid
1313            + values_nnz_stride * nnz_offset
1314            + values_row_block_stride * row_block_arange[:, None]
1315            + values_col_block_stride * col_block_arange[None, :]
1316        )
1317
1318        col_index_nnz_ptr = (
1319            col_indices_ptr
1320            + col_indices_batch_stride * batch_pid
1321            + col_indices_stride * nnz_offset
1322        )
1323
1324        # Advance mat1 to the current tiled row, ignore columns.
1325        mat1_block_ptrs = (
1326            mat1_ptr
1327            + mat1_batch_stride * batch_pid
1328            + mat1_tiled_row_stride * row_block_pid
1329            + mat1_row_block_stride * row_block_arange[:, None]
1330        )
1331
1332        # Advance mat2 in batch and block col dimension.
1333        mat2_block_ptrs = (
1334            mat2_ptr
1335            + mat2_batch_stride * batch_pid
1336            + mat2_col_block_stride * col_block_arange[None, :]
1337        )
1338
1339        k_tile_arange = tl.arange(0, TILE_K)
1340        for _ in range(row_nnz):
1341            acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype)
1342
1343            # find column block index
1344            col_block = tl.load(col_index_nnz_ptr)
1345
1346            for k_tile in range(0, k, TILE_K):
1347                k_offsets = k_tile + k_tile_arange
1348                mask_k = k_offsets < k
1349
1350                mat1_block = tl.load(
1351                    mat1_block_ptrs + mat1_col_block_stride * k_offsets[None, :],
1352                    mask=mask_k[None, :],
1353                    other=0.0,
1354                )
1355
1356                mat2_block = tl.load(
1357                    mat2_block_ptrs
1358                    + mat2_tiled_col_stride * col_block
1359                    + mat2_row_block_stride * k_offsets[:, None],
1360                    mask=mask_k[:, None],
1361                    other=0.0,
1362                )
1363
1364                acc_block += tl.dot(
1365                    mat1_block, mat2_block, allow_tf32=allow_tf32, out_dtype=acc_dtype
1366                )
1367
1368            if IS_BETA_ZERO:
1369                acc_block *= alpha
1370            else:
1371                acc_block = alpha * acc_block + beta * tl.load(values_block_ptrs)
1372
1373            # write result
1374            tl.store(values_block_ptrs, acc_block.to(values_ptr.dtype.element_ty))
1375
1376            # advance val/col_index ptrs to the next block in the row.
1377            values_block_ptrs += values_nnz_stride
1378            col_index_nnz_ptr += col_indices_stride
1379
1380    @triton.jit
1381    def _bsr_strided_dense_rowspace_kernel(
1382        # values prologue
1383        values_ptr,
1384        values_batch_stride,
1385        values_nnz_stride,
1386        values_row_block_stride,
1387        values_col_block_stride,
1388        # values epilogue
1389        # crow_indices prologue
1390        crow_indices_ptr,
1391        crow_indices_batch_stride,
1392        crow_indices_stride,
1393        # crow_indices epilogue
1394        # col_indices prologue
1395        col_indices_ptr,
1396        col_indices_batch_stride,
1397        col_indices_stride,
1398        # col_indices epilogue
1399        # dense prologue
1400        dense_ptr,
1401        dense_batch_stride,
1402        dense_tiled_row_stride,
1403        dense_tiled_col_stride,
1404        dense_row_block_stride,
1405        dense_col_block_stride,
1406        # dense epilogue
1407        # output prologue
1408        output_ptr,
1409        output_batch_stride,
1410        output_tiled_row_stride,
1411        output_tiled_col_stride,
1412        output_row_block_stride,
1413        output_col_block_stride,
1414        # output epilogue
1415        #
1416        # gh-113754: Always keep all constexpr arguments at the end of
1417        # triton kernel arguments list because with triton 2.1 or
1418        # earlier non-contiguous outputs will corrupt CUDA state due
1419        # to a triton bug (fixed in openai/triton#2262).
1420        BLOCKSIZE_ROW: tl.constexpr,
1421        BLOCKSIZE_COL: tl.constexpr,
1422        acc_dtype: tl.constexpr,
1423        allow_tf32: tl.constexpr,
1424        GROUP_SIZE_ROW: tl.constexpr,
1425    ):
1426        batch_pid = tl.program_id(axis=2)
1427        row_block_pid = tl.program_id(axis=0)
1428        col_block_pid = tl.program_id(axis=1)
1429        n_block_rows = tl.num_programs(axis=0)
1430        n_block_cols = tl.num_programs(axis=1)
1431
1432        row_block_pid, col_block_pid = tl.swizzle2d(
1433            row_block_pid, col_block_pid, n_block_rows, n_block_cols, GROUP_SIZE_ROW
1434        )
1435
1436        crow_indices_offset_ptr = (
1437            crow_indices_ptr
1438            + crow_indices_batch_stride * batch_pid
1439            + crow_indices_stride * row_block_pid
1440        )
1441        nnz_offset = tl.load(crow_indices_offset_ptr)
1442        nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)
1443
1444        # Compute nnz for the row with number row_block_pid.
1445        # If it is zero, skip the row.
1446        row_nnz = nnz_offset_next - nnz_offset
1447        if row_nnz == 0:
1448            return
1449
1450        row_block_arange = tl.arange(0, BLOCKSIZE_ROW)
1451        col_block_arange = tl.arange(0, BLOCKSIZE_COL)
1452
1453        # Pointers are set to the first block of the current row.
1454        values_block_ptrs = (
1455            values_ptr
1456            + values_batch_stride * batch_pid
1457            + values_nnz_stride * nnz_offset
1458            + values_row_block_stride * row_block_arange[:, None]
1459            + values_col_block_stride * col_block_arange[None, :]
1460        )
1461
1462        # NOTE: dense is advanced into all dimensions but the tiled row one.
1463        # That will be advanced in the loop according to values in col_indices.
1464        dense_block_ptrs = (
1465            dense_ptr
1466            + dense_batch_stride * batch_pid
1467            + dense_tiled_col_stride * col_block_pid
1468            + dense_row_block_stride * col_block_arange[:, None]
1469            + dense_col_block_stride * row_block_arange[None, :]
1470        )
1471
1472        # Pointers are set to exact write-to locations
1473        output_ptrs = (
1474            output_ptr
1475            + output_batch_stride * batch_pid
1476            + output_tiled_row_stride * row_block_pid
1477            + output_tiled_col_stride * col_block_pid
1478            + output_row_block_stride * row_block_arange[:, None]
1479            + output_col_block_stride * row_block_arange[None, :]
1480        )
1481
1482        # Set pointer to the first nonzero element in the current row
1483        col_index_nnz_ptr = (
1484            col_indices_ptr
1485            + col_indices_batch_stride * batch_pid
1486            + col_indices_stride * nnz_offset
1487        )
1488
1489        output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype)
1490        for _ in range(row_nnz):
1491            values_block = tl.load(values_block_ptrs)
1492
1493            # find which row of dense needs to get loaded
1494            # for multiplication with values_block.
1495            dense_row_idx = tl.load(col_index_nnz_ptr)
1496            dense_block = tl.load(
1497                dense_block_ptrs + dense_tiled_row_stride * dense_row_idx
1498            )
1499
1500            # do block mm
1501            output_acc_block += tl.dot(
1502                values_block, dense_block, allow_tf32=allow_tf32, out_dtype=acc_dtype
1503            )
1504
1505            # move val/col_index ptrs to the next block in the row
1506            values_block_ptrs += values_nnz_stride
1507            col_index_nnz_ptr += col_indices_stride
1508
1509        # write back the result
1510        tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty))
1511
1512    def _run_sampled_addmm_kernel(
1513        alpha,
1514        beta,
1515        is_beta_zero,
1516        blocksize,
1517        k,
1518        tile_k,
1519        values,
1520        crow_indices,
1521        col_indices,
1522        mat1,
1523        mat2,
1524        max_grid,
1525    ):
1526        n_batches = values.size(0)
1527        n_block_rows = crow_indices.size(-1) - 1
1528
1529        full_grid = (n_batches, n_block_rows)
1530        if max_grid is not None:
1531            grid_blocks = tuple(max_grid[:2][::-1]) + (None,) * (2 - len(max_grid[:2]))
1532        else:
1533            grid_blocks = None
1534        tensor_dims_map = {
1535            values: (0, None),
1536            crow_indices: (0, -1),
1537            col_indices: (0, None),
1538            mat1: (0, -4),
1539            mat2: (0, None),
1540        }
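        # Block products for half/bfloat16 inputs are accumulated in float32
        # with TF32 allowed; all other supported dtypes are accumulated in
        # float64 with TF32 disabled.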
1541        if values.dtype in (torch.half, torch.bfloat16):
1542            acc_dtype = tl.float32
1543            allow_tf32 = True
1544        else:
1545            acc_dtype = tl.float64
1546            allow_tf32 = False
1547
1548        def kernel(grid, *sliced_tensors):
1549            _sampled_addmm_kernel[grid](
1550                alpha,
1551                beta,
1552                is_beta_zero,
1553                *blocksize,
1554                k,
1555                tile_k,
1556                *ptr_stride_extractor(*sliced_tensors),
1557                acc_dtype=acc_dtype,
1558                allow_tf32=allow_tf32,
1559                num_stages=1,
1560                num_warps=4,
1561            )
1562
1563        launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks)
1564
1565    def sampled_addmm(
1566        input: torch.Tensor,
1567        mat1: torch.Tensor,
1568        mat2: torch.Tensor,
1569        *,
1570        beta=1.0,
1571        alpha=1.0,
1572        out: Optional[torch.Tensor] = None,
1573        skip_checks: bool = False,
1574        max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,
1575    ):
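        """Sampled addmm over a BSR mask.

        Computes ``beta * input + alpha * (mat1 @ mat2)`` restricted to the
        nonzero blocks of the BSR tensor ``input`` and returns the result as
        a BSR tensor.

        Example (editorial sketch, not part of the original source; assumes a
        CUDA device with Triton available)::

            mask = torch.eye(256, device="cuda", dtype=torch.float16)
            bsr = mask.to_sparse_bsr((32, 32))
            mat1 = torch.randn(256, 64, device="cuda", dtype=torch.float16)
            mat2 = torch.randn(64, 256, device="cuda", dtype=torch.float16)
            res = sampled_addmm(bsr, mat1, mat2, beta=0.0)  # BSR result
        """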
1576        f_name = "sampled_addmm"
1577
1578        check_bsr_layout(f_name, input)
1579        input_broadcasted = broadcast_batch_dims_bsr(f_name, input, mat1, mat2)
1580
1581        if not skip_checks:
1582            check_device(f_name, mat1, input.device)
1583            check_device(f_name, mat2, input.device)
1584            if beta != 0.0 and input.dtype is torch.bool:
1585                check(
1586                    False,
1587                    f"{f_name}(): having beta == {beta} not equal to 0.0 with boolean mask is not allowed.",
1588                )
1589            if input.dtype is not torch.bool:
1590                check_dtype(f_name, mat1, input.dtype)
1591                check_dtype(f_name, mat2, input.dtype)
1592            else:
1593                check_dtype(f_name, mat1, mat2.dtype)
1594            check_mm_compatible_shapes(f_name, mat1, mat2)
1595            if out is not None:
1596                check_bsr_layout(f_name, out)
1597                check_device(f_name, out, mat1.device)
1598                check_dtype(f_name, out, input.dtype)
1599                check(
1600                    out.shape == input_broadcasted.shape and out._nnz() == input._nnz(),
1601                    f"{f_name}(): Expects `out` to be of shape {input_broadcasted.shape} "
1602                    f"and with nnz equal to {input_broadcasted._nnz()} "
1603                    f"but got out.shape = {out.shape} and out.nnz = {out._nnz()}",
1604                )
1605
1606        if out is None:
1607            out = input_broadcasted.to(mat1.dtype, copy=True)
1608        else:
1609            out.copy_(input_broadcasted)
1610
1611        if out.numel() == 0 or out._nnz() == 0:
1612            return out
1613
1614        blocksize = out.values().shape[-2:]
1615        m = mat1.size(-2)
1616        n = mat2.size(-1)
1617        k = mat1.size(-1)
1618
1619        # NOTE: (m, 0) @ (0, n) == zeros(m, n)
1620        if alpha == 0.0 or k == 0:
1621            out.values().mul_(beta)
1622            return out
1623
1624        # prepare inputs by reshaping them to be kernel-compatible
1625        out_backup = out
1626        crow_indices, col_indices, values, mat1, mat2 = prepare_inputs(out, mat1, mat2)
1627
1628        mat1 = tile_to_blocksize(mat1, (blocksize[0], k))
1629        mat2 = tile_to_blocksize(mat2, (k, blocksize[1]))
1630        tile_k = max(*blocksize)
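        # mat1 is viewed as (blocksize[0], k) row tiles and mat2 as
        # (k, blocksize[1]) column tiles so that the kernel can address whole
        # tiles while filling in the nonzero blocks of `out`.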
1631
1632        _run_sampled_addmm_kernel(
1633            alpha,
1634            beta,
1635            beta == 0.0,
1636            blocksize,
1637            k,
1638            tile_k,
1639            values,
1640            crow_indices,
1641            col_indices,
1642            mat1,
1643            mat2,
1644            max_grid,
1645        )
1646
1647        # If the nnz x block strides of out_backup.values() and values differ,
1648        # then out_backup.values() and values are not views of each other
1649        # and the result has to be copied back.
1650        if out_backup.values().stride()[-3:] != values.stride()[-3:]:
1651            out_backup.values().copy_(values.reshape(out_backup.values().shape))
1652        return out_backup
1653
1654    def bsr_dense_mm(
1655        bsr: torch.Tensor,
1656        dense: torch.Tensor,
1657        *,
1658        out: Optional[torch.Tensor] = None,
1659        skip_checks: bool = False,
1660        max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,
1661        meta: Optional[dict] = None,
1662    ):
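        """Multiply a BSR sparse matrix by a dense matrix: ``bsr @ dense``.

        Example (editorial sketch, not part of the original source; assumes a
        CUDA device with Triton available)::

            lhs = torch.eye(128, device="cuda", dtype=torch.float16)
            bsr = lhs.to_sparse_bsr((16, 16))
            rhs = torch.randn(128, 64, device="cuda", dtype=torch.float16)
            res = bsr_dense_mm(bsr, rhs)  # dense result of shape (128, 64)
        """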
1663        f_name = "bsr_dense_mm"
1664        m, kl = bsr.shape[-2:]
1665        if not skip_checks:
1666            check_bsr_layout(f_name, bsr)
1667            check_device(f_name, bsr, dense.device)
1668            check_dtype(f_name, bsr, dense.dtype, (torch.int8,))
1669            check_mm_compatible_shapes(f_name, bsr, dense)
1670
1671            n = dense.size(-1)
1672            row_block, col_block = bsr.values().shape[-2:]
1673            check_blocksize(f_name, (row_block, col_block))
1674            check(
1675                not n % 16,
1676                f"{f_name}(): dense.size(-1) == {n} should be divisible by 16",
1677            )
1678        else:
1679            kr, n = dense.shape[-2:]
1680
1681        original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense)
1682
1683        if out is not None and not skip_checks:
1684            expected_out_shape = original_batch_dims_broadcasted + (m, n)
1685            check(
1686                out.shape == expected_out_shape,
1687                "bsr_dense_mm(): `out` argument has wrong shape, "
1688                f"expected {expected_out_shape}, but got {out.shape}.",
1689            )
1690            check(
1691                out.is_contiguous() or out.transpose(-2, -1).is_contiguous(),
1692                "bsr_dense_mm(): only row-major/col-major `out` arguments are supported, "
1693                "i.e. (out.is_contiguous() or out.transpose(-2, -1).is_contiguous()) "
1694                "should be True.",
1695            )
1696
1697        # Allocate out
1698        if out is None:
1699            out = dense.new_empty(original_batch_dims_broadcasted + (m, n))
1700
1701        # Short circuit if lhs is zero
1702        if bsr._nnz() == 0:
1703            return out.zero_()
1704
1705        # with beta==0, addmm ignores input content, so we can use out
1706        # as a placeholder for input because their shapes match:
1707        return bsr_dense_addmm(out, bsr, dense, alpha=1, beta=0, out=out)
1708
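    # Row-wise softmax over the nonzero values of a BSR tensor.  Each program
    # instance handles one row of one block row and streams over that row's
    # row_nnz * col_block values in TILE-sized chunks: one pass to find the
    # row maximum, one to accumulate the denominator, and a final pass to
    # write the normalized values back in place.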
1709    @triton.jit
1710    def _bsr_softmax_kernel(
1711        crow_indices_ptr,
1712        crow_indices_batch_stride,
1713        crow_indices_stride,
1714        values_ptr,
1715        values_batch_stride,
1716        values_row_block_stride,
1717        values_nnz_col_block_stride,
1718        row_block,
1719        col_block,
1720        MAX_ROW_NNZ: tl.constexpr,
1721        TILE: tl.constexpr,
1722    ):
1723        batch_pid = tl.program_id(axis=2)
1724        row_block_offset_pid = tl.program_id(axis=1)
1725        row_block_pid = tl.program_id(axis=0)
1726
1727        crow_indices_offset_ptr = (
1728            crow_indices_ptr
1729            + crow_indices_batch_stride * batch_pid
1730            + crow_indices_stride * row_block_pid
1731        )
1732        nnz_offset = tl.load(crow_indices_offset_ptr)
1733        nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)
1734
1735        # Compute nnz for the row with number row_block_pid.
1736        # If it is zero, skip the row.
1737        row_nnz = nnz_offset_next - nnz_offset
1738        if row_nnz == 0:
1739            return
1740
1741        row_arange = tl.arange(0, TILE)
1742        mask = row_arange < row_nnz * col_block
1743
1744        curr_row_values_ptrs = (
1745            values_ptr
1746            + values_batch_stride * batch_pid
1747            + values_row_block_stride * row_block_offset_pid
1748            + nnz_offset * col_block
1749        )
1750
1751        # find max in the row
1752        row_tile = tl.load(
1753            curr_row_values_ptrs + row_arange, mask=mask, other=-float("inf")
1754        ).to(tl.float32)
1755        max_row_value = tl.max(row_tile, axis=0)
1756        for _ in range(TILE, MAX_ROW_NNZ, TILE):
1757            row_arange += TILE
1758            mask = row_arange < row_nnz * col_block
1759            row_tile = tl.load(
1760                curr_row_values_ptrs + row_arange, mask=mask, other=-float("inf")
1761            ).to(tl.float32)
1762            curr_max_row_value = tl.max(row_tile, axis=0)
1763            max_row_value = tl.where(
1764                max_row_value > curr_max_row_value, max_row_value, curr_max_row_value
1765            )
1766
1767        # find denominator for stable softmax
1768        num = tl.exp(row_tile - max_row_value)
1769        denom = tl.sum(num, axis=0)
1770        for _ in range(TILE, MAX_ROW_NNZ, TILE):
1771            row_arange -= TILE
1772            mask = row_arange < row_nnz * col_block
1773            row_tile = tl.load(
1774                curr_row_values_ptrs + row_arange, mask=mask, other=-float("inf")
1775            ).to(tl.float32)
1776            num = tl.exp(row_tile - max_row_value)
1777            denom += tl.sum(num, axis=0)
1778
1779        # populate output
1780        tl.store(
1781            curr_row_values_ptrs + row_arange,
1782            (num / denom).to(values_ptr.dtype.element_ty),
1783            mask=mask,
1784        )
1785        for _ in range(TILE, MAX_ROW_NNZ, TILE):
1786            row_arange += TILE
1787            mask = row_arange < row_nnz * col_block
1788            row_tile = tl.load(
1789                curr_row_values_ptrs + row_arange, mask=mask, other=-float("inf")
1790            ).to(tl.float32)
1791            num = tl.exp(row_tile - max_row_value)
1792            tl.store(
1793                curr_row_values_ptrs + row_arange,
1794                (num / denom).to(values_ptr.dtype.element_ty),
1795                mask=mask,
1796            )
1797
1798    def bsr_softmax(input, max_row_nnz=None):
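        """Apply a row-wise softmax over the nonzero values of a BSR tensor.

        The softmax of each row is computed over that row's specified values
        only; unspecified (zero) blocks do not contribute.

        Example (editorial sketch, not part of the original source; assumes a
        CUDA device with Triton available)::

            a = torch.randn(128, 128, device="cuda", dtype=torch.float16)
            bsr = a.to_sparse_bsr((16, 16))
            probs = bsr_softmax(bsr)  # same sparsity pattern as `bsr`
        """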
1799        f_name = "bsr_softmax"
1800
1801        check_bsr_layout(f_name, input)
1802        check_dtype(f_name, input, input.dtype)
1803
1804        if input._nnz() == 0 or input.numel() == 0:
1805            return input.clone()
1806
1807        m, n = input.shape[-2:]
1808        nnz = input._nnz()
1809        row_block, col_block = input.values().shape[-2:]
1810
1811        if max_row_nnz is None:
1812            max_row_nnz = triton.next_power_of_2(n)
1813        else:
1814            max_row_nnz = triton.next_power_of_2(max_row_nnz)
1815
1816        crow_indices = input.crow_indices().unsqueeze(0).flatten(0, -2)
1817        # reshape values from
1818        # (b1, ..., bn, nnz, row_block, col_block) to
1819        # (b1 * ... * bn, row_block, nnz * col_block).
1820        # This simplifies batch dim manipulation and makes it possible
1821        # to access all the nnz values of any given row.
1822        if input.values().transpose(-3, -2).is_contiguous():
1823            # Clone so that the kernel below does not write into input's values via a view.
1824            values = input.values().clone()
1825        else:
1826            values = input.values()
1827        values = (
1828            values.transpose(-3, -2)
1829            .contiguous()
1830            .unsqueeze(0)
1831            .flatten(0, -4)
1832            .reshape(-1, row_block, nnz * col_block)
1833        )
1834        full_grid = (values.shape[0], row_block, m // row_block)
1835        grid_blocks = None
1836        tensor_dims_map = {
1837            # The grid spans block rows; crow_indices has n_block_rows + 1
1838            # entries, hence crow_indices[..., :-1].
1839            crow_indices[..., :-1]: (0, None, -1),
1840            values: (0, None, None),
1841        }
1842
1843        def kernel(grid, *sliced_tensors):
1844            _bsr_softmax_kernel[grid](
1845                *ptr_stride_extractor(*sliced_tensors),
1846                row_block,
1847                col_block,
1848                max_row_nnz,
1849                # Triton's max numel is bounded by 2 ** 17.
1850                min(2**17, max_row_nnz),
1851            )
1852
1853        launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks)
1854
1855        values = (
1856            values.reshape(-1, row_block, nnz, col_block)
1857            .transpose(-3, -2)
1858            .reshape(*input.values().shape)
1859        )
1860
1861        return torch.sparse_compressed_tensor(
1862            input.crow_indices().clone(),
1863            input.col_indices().clone(),
1864            values,
1865            size=input.shape,
1866            layout=input.layout,
1867        )
1868
1869    def _scaled_dot_product_attention(
1870        query: torch.Tensor,
1871        key: torch.Tensor,
1872        value: torch.Tensor,
1873        attn_mask: Optional[torch.Tensor],
1874        dropout_p: float = 0.0,
1875        is_causal: bool = False,
1876        scale: Optional[float] = None,
1877    ):
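        """Scaled dot product attention with a BSR ``attn_mask``.

        The mask's sparsity pattern is used to sample
        ``query @ key.transpose(-2, -1)`` via ``sampled_addmm``; the result is
        scaled, normalized with ``bsr_softmax``, passed through dropout
        (applied in place to the values), and finally multiplied with
        ``value`` via ``bsr_dense_mm``.
        """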
1878        f_name = "_scaled_dot_product_attention"
1879        check(not is_causal, f"{f_name}(): is_causal == True is not supported.")
1880        check(attn_mask is not None, f"{f_name}(): attn_mask == None is not supported.")
1881        assert attn_mask is not None
1882
1883        check(
1884            attn_mask.layout == torch.sparse_bsr,
1885            f"{f_name}(): "
1886            f"attn_mask.layout must be {torch.sparse_bsr}, but got "
1887            f"attn_mask.layout == {attn_mask.layout}.",
1888        )
1889
1890        check_device(f_name, key, query.device)
1891        check_device(f_name, value, query.device)
1892        check_device(f_name, attn_mask, query.device)
1893
1894        check_dtype(f_name, key, query.dtype)
1895        check_dtype(f_name, value, query.dtype)
1896        if attn_mask.dtype is not torch.bool:
1897            check_dtype(f_name, attn_mask, query.dtype)
1898
1899        sdpa = sampled_addmm(
1900            attn_mask, query, key.transpose(-2, -1), beta=0.0, skip_checks=False
1901        )
1902        if (scale is None and query.size(-1) == 0) or scale == 0.0:
1903            check(
1904                False,
1905                f"{f_name}(): current value of scale == {scale} "
1906                "results in division by zero.",
1907            )
1908        scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
1909        sdpa.values().mul_(scale_factor)
1910        sdpa = bsr_softmax(sdpa)
1911        torch.nn.functional.dropout(sdpa.values(), p=dropout_p, inplace=True)
1912        sdpa = bsr_dense_mm(sdpa, value)
1913        return sdpa
1914
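    # Scatter-matmul kernel: for every output slot t, accumulate the products
    # blocks[p] @ others[q] over the (p, q) pairs stored at positions
    # pq_offsets[t]:pq_offsets[t + 1] of the pq index list, and write the sum
    # into accumulators[t].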
1915    @triton.jit
1916    def _scatter_mm2_kernel(
1917        M: tl.constexpr,
1918        K: tl.constexpr,
1919        N: tl.constexpr,
1920        blocks_ptr,
1921        blocks_stride_P,
1922        blocks_stride_M,
1923        blocks_stride_K,
1924        others_ptr,
1925        others_stride_Q,
1926        others_stride_K,
1927        others_stride_N,
1928        accumulators_ptr,
1929        accumulators_stride_R,
1930        accumulators_stride_M,
1931        accumulators_stride_N,
1932        pq_offsets_ptr,
1933        pq_offsets_stride,
1934        pq_ptr,
1935        pq_stride_T,
1936        pq_stride_1,
1937        dot_out_dtype: tl.constexpr,
1938        TILE_M: tl.constexpr,
1939        TILE_N: tl.constexpr,
1940        allow_tf32: tl.constexpr,
1941    ):
1942        Ms = M // TILE_M
1943        Ns = N // TILE_N
1944
1945        pid_t = tl.program_id(axis=0)
1946
1947        pid = tl.program_id(axis=1)
1948        pid_m = pid // Ms
1949        pid_n = pid % Ms
1950
1951        rm = pid_m * TILE_M + tl.arange(0, TILE_M)
1952        rn = pid_n * TILE_N + tl.arange(0, TILE_N)
1953        rk = tl.arange(0, K)
1954
1955        A_ptr = blocks_ptr + (
1956            rm[:, None] * blocks_stride_M + rk[None, :] * blocks_stride_K
1957        )
1958        B_ptr = others_ptr + (
1959            rk[:, None] * others_stride_K + rn[None, :] * others_stride_N
1960        )
1961
1962        g0 = tl.load(pq_offsets_ptr + pid_t * pq_offsets_stride)
1963        g1 = tl.load(pq_offsets_ptr + (pid_t + 1) * pq_offsets_stride)
1964
1965        if g0 == g1:
1966            return
1967
1968        acc_block = tl.zeros((TILE_M, TILE_N), dtype=dot_out_dtype)
1969
1970        for i in range(g0, g1):
1971            p = tl.load(pq_ptr + i * pq_stride_T)
1972            q = tl.load(pq_ptr + i * pq_stride_T + pq_stride_1)
1973            A = tl.load(A_ptr + p * blocks_stride_P)
1974            B = tl.load(B_ptr + q * others_stride_Q)
1975            acc_block += tl.dot(A, B, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
1976
1977        C_ptr = (
1978            accumulators_ptr
1979            + pid_t * accumulators_stride_R
1980            + (
1981                rm[:, None] * accumulators_stride_M
1982                + rn[None, :] * accumulators_stride_N
1983            )
1984        )
1985        tl.store(C_ptr, acc_block.to(accumulators_ptr.dtype.element_ty))
1986
1987    def _scatter_mm2(
1988        blocks: torch.Tensor,
1989        others: torch.Tensor,
1990        pq_offsets: torch.Tensor,
1991        pq_indices: torch.Tensor,
1992        accumulators: torch.Tensor,
1993    ):
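        """Compute ``accumulators[r] = sum_i blocks[p_i] @ others[q_i]``.

        The (p_i, q_i) pairs contributing to slot ``r`` are
        ``pq_indices[pq_offsets[r]:pq_offsets[r + 1]]``.  Expected shapes:
        blocks (P, M, K), others (Q, K, N), accumulators (R, M, N),
        pq_offsets (R + 1,), pq_indices (T, 2).
        """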
1994        P, M, K = blocks.shape
1995        Q, _, N = others.shape
1996        R, _, _ = accumulators.shape
1997
1998        meta = dict(
1999            TILE_M=max(16, M // 4), TILE_N=max(16, N // 4), num_stages=1, num_warps=2
2000        )
2001
2002        def grid(META):
2003            return (
2004                pq_offsets.shape[0] - 1,
2005                triton.cdiv(M, META["TILE_M"]) * triton.cdiv(N, META["TILE_N"]),
2006                1,
2007            )
2008
2009        dot_out_dtype = {
2010            torch.float16: tl.float32,
2011            torch.bfloat16: tl.float32,
2012            torch.float32: tl.float64,
2013            torch.float64: tl.float64,
2014        }[accumulators.dtype]
2015        if "allow_tf32" not in meta:
2016            meta.update(allow_tf32=dot_out_dtype == tl.float32)
2017        _scatter_mm2_kernel[grid](
2018            M,
2019            K,
2020            N,
2021            blocks,
2022            blocks.stride(0),
2023            blocks.stride(1),
2024            blocks.stride(2),
2025            others,
2026            others.stride(0),
2027            others.stride(1),
2028            others.stride(2),
2029            accumulators,
2030            accumulators.stride(0),
2031            accumulators.stride(1),
2032            accumulators.stride(2),
2033            pq_offsets,
2034            pq_offsets.stride(0),
2035            pq_indices,
2036            pq_indices.stride(0),
2037            pq_indices.stride(1),
2038            dot_out_dtype=dot_out_dtype,
2039            **meta,
2040        )
2041
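    # Batched scatter-matmul kernel: program axis 0 enumerates
    # (output slot, batch) pairs via r_offsets, while program axis 1 enumerates
    # TILE_M x TILE_N output tiles using the grouped ("swizzled") ordering
    # controlled by GROUP_SIZE.  With is_compressed=True the p indices form a
    # consecutive range taken from the CSR-like c_indices (q still comes from
    # q_offsets); otherwise explicit (p, q) pairs are read from
    # p_offsets/q_offsets.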
2042    @triton.jit
2043    def _scatter_mm6_kernel(
2044        nbatches,
2045        Ms,
2046        Ks: tl.constexpr,
2047        N,
2048        blocks_ptr,
2049        blocks_stride_P,
2050        blocks_stride_M,
2051        blocks_stride_K,
2052        others_ptr,
2053        others_stride_B,
2054        others_stride_K,
2055        others_stride_N,
2056        accumulators_ptr,
2057        accumulators_stride_B,
2058        accumulators_stride_M,
2059        accumulators_stride_N,
2060        c_indices_ptr,
2061        r_offsets_ptr,
2062        p_offsets_ptr,
2063        q_offsets_ptr,
2064        is_compressed: tl.constexpr,
2065        dot_out_dtype: tl.constexpr,
2066        SPLIT_N: tl.constexpr,
2067        TILE_M: tl.constexpr,
2068        TILE_N: tl.constexpr,
2069        GROUP_SIZE: tl.constexpr,
2070        allow_tf32: tl.constexpr,
2071    ):
2072        Ns = N // SPLIT_N
2073        BLOCKS_M = Ms // TILE_M
2074        BLOCKS_N = Ns // TILE_N
2075
2076        pid_t_ = tl.program_id(axis=0)
2077        pid = tl.program_id(axis=1)
2078        pid_b = pid_t_ % nbatches
2079        pid_t = pid_t_ // nbatches
2080
2081        num_pid_in_group = GROUP_SIZE * BLOCKS_N
2082        group_id = pid // num_pid_in_group
2083        first_pid_m = group_id * GROUP_SIZE
2084        group_size_m = min(BLOCKS_M - first_pid_m, GROUP_SIZE)
2085        pid_m = first_pid_m + (pid % group_size_m)
2086        pid_n = (pid % num_pid_in_group) // group_size_m
2087
2088        rm = pid_m * TILE_M + tl.arange(0, TILE_M)
2089        rn = pid_n * TILE_N + tl.arange(0, TILE_N)
2090        rk = tl.arange(0, Ks)
2091        A_ptr = blocks_ptr + (
2092            rm[:, None] * blocks_stride_M + rk[None, :] * blocks_stride_K
2093        )
2094        B_ptr = (
2095            others_ptr
2096            + pid_b * others_stride_B
2097            + (rk[:, None] * others_stride_K + rn[None, :] * others_stride_N)
2098        )
2099
2100        # When is_compressed is True, r is the only variable that
2101        # depends on pid_t. This property allows sorting r values
2102        # before calling the kernel. The sorting of r is equivalent to
2103        # defining a swizzle operator outside of the kernel.
2104        r = tl.load(r_offsets_ptr + pid_t)
2105
2106        if is_compressed:
2107            m = (r // N) // Ms
2108            n = (r % N) // Ns
2109            r0 = tl.load(c_indices_ptr + m)
2110            r1 = tl.load(c_indices_ptr + m + 1)
2111            g0 = n * r1 + (SPLIT_N - n) * r0
2112            nnz = r1 - r0
2113        else:
2114            g0 = tl.load(c_indices_ptr + pid_t)
2115            g1 = tl.load(c_indices_ptr + pid_t + 1)
2116            nnz = g1 - g0
2117
2118        q_ptr = q_offsets_ptr + g0
2119        acc_block = tl.zeros((TILE_M, TILE_N), dtype=dot_out_dtype)
2120
2121        if is_compressed:
2122            A_ptr += r0 * blocks_stride_P  # type: ignore[possibly-undefined]
2123            for _ in range(nnz):
2124                q = tl.load(q_ptr)
2125                B = tl.load(B_ptr + q)
2126                A = tl.load(A_ptr)
2127                acc_block += tl.dot(
2128                    A, B, out_dtype=dot_out_dtype, allow_tf32=allow_tf32
2129                )
2130                A_ptr += blocks_stride_P
2131                q_ptr += 1
2132        else:
2133            p_ptr = p_offsets_ptr + g0
2134            for _ in range(nnz):
2135                q = tl.load(q_ptr)
2136                B = tl.load(B_ptr + q)
2137                p = tl.load(p_ptr)
2138                A = tl.load(A_ptr + p * blocks_stride_P)
2139                p_ptr += 1
2140                q_ptr += 1
2141                acc_block += tl.dot(
2142                    A, B, out_dtype=dot_out_dtype, allow_tf32=allow_tf32
2143                )
2144
2145        C_ptr = (
2146            accumulators_ptr
2147            + r
2148            + pid_b * accumulators_stride_B
2149            + (
2150                rm[:, None] * accumulators_stride_M
2151                + rn[None, :] * accumulators_stride_N
2152            )
2153        )
2154        tl.store(C_ptr, acc_block.to(accumulators_ptr.dtype.element_ty))
2155
2156    def _scatter_mm6(
2157        blocks: torch.Tensor,
2158        others: torch.Tensor,
2159        c_indices: torch.Tensor,
2160        r_offsets: torch.Tensor,
2161        p_offsets: torch.Tensor,
2162        q_offsets: torch.Tensor,
2163        meta: dict,
2164        accumulators: torch.Tensor,
2165        force_contiguous: bool = True,
2166    ):
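        """Batched scatter-matmul driver for ``_scatter_mm6_kernel``.

        Expected shapes: blocks (P, Ms, Ks), others (B, K, N) and
        accumulators (B, M, N).  ``meta`` must provide at least SPLIT_N,
        TILE_M, TILE_N and GROUP_SIZE; results are written at the flat
        element offsets of ``accumulators`` given by ``r_offsets``.
        """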
2167        SPLIT_N = meta["SPLIT_N"]
2168        P, Ms, Ks = blocks.shape
2169        B, K_, N = others.shape
2170        B_, M, N_ = accumulators.shape
2171        assert N_ == N
2172        Ns = N // SPLIT_N
2173        assert B_ == B
2174
2175        def grid(META):
2176            return (
2177                r_offsets.shape[0] * B,
2178                triton.cdiv(Ms, META["TILE_M"]) * triton.cdiv(Ns, META["TILE_N"]),
2179            )
2180
2181        dot_out_dtype = {
2182            torch.float16: tl.float32,
2183            torch.bfloat16: tl.float32,
2184            torch.float32: tl.float64,
2185            torch.float64: tl.float64,
2186        }[accumulators.dtype]
2187        if "allow_tf32" not in meta:
2188            meta.update(allow_tf32=dot_out_dtype == tl.float32)
2189
2190        assert c_indices.stride(0) == 1
2191        assert r_offsets.stride(0) == 1
2192        assert p_offsets.stride(0) == 1
2193        assert q_offsets.stride(0) == 1
2194
2195        # Note on non-contiguous tensor arguments. Triton kernel launches
2196        # may sometimes fail with
2197        #
2198        #   RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered
2199        #
2200        # which appears to happen when the size of a non-contiguous
2201        # tensor argument exceeds a certain threshold. Could this be
2202        # related to the shared memory or L1 cache size of the GPU
2203        # card? In any case, ensuring that tensor arguments are
2204        # contiguous seems to avoid the above exception. So, in the
2205        # following we always convert tensor arguments to
2206        # C-contiguous tensors.
2207
2208        if force_contiguous:
2209            blocks = blocks.contiguous()
2210            others = others.contiguous()
2211            if not accumulators.is_contiguous():
2212                accumulators_ = accumulators.contiguous()
2213            else:
2214                accumulators_ = accumulators
2215        else:
2216            accumulators_ = accumulators
2217
2218        _scatter_mm6_kernel[grid](
2219            B,
2220            Ms,
2221            Ks,
2222            N,
2223            blocks,
2224            blocks.stride(0),
2225            blocks.stride(1),
2226            blocks.stride(2),
2227            others,
2228            others.stride(0),
2229            others.stride(1),
2230            others.stride(2),
2231            accumulators_,
2232            accumulators_.stride(0),
2233            accumulators_.stride(1),
2234            accumulators_.stride(2),
2235            c_indices,
2236            r_offsets,
2237            p_offsets,
2238            q_offsets,
2239            dot_out_dtype=dot_out_dtype,
2240            **meta,
2241        )
2242
2243        if force_contiguous and not accumulators.is_contiguous():
2244            accumulators.copy_(accumulators_)
2245
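    # Computes out = beta * input + alpha * (bsr @ dense) one output block at
    # a time.  When beta is nonzero the accumulator is seeded with
    # input * (beta / alpha) and the whole block is multiplied by alpha at the
    # end, so the inner loop only performs plain block dot products.
    # tl.swizzle2d reorders the (row, col) program ids for better cache reuse.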
2246    @triton.jit
2247    def _bsr_strided_addmm_kernel(
2248        # values prologue
2249        values_ptr,
2250        values_batch_stride,
2251        values_nnz_stride,
2252        values_row_block_stride,
2253        values_col_block_stride,
2254        # values epilogue
2255        # crow_indices prologue
2256        crow_indices_ptr,
2257        crow_indices_batch_stride,
2258        crow_indices_stride,
2259        # crow_indices epilogue
2260        # col_indices prologue
2261        col_indices_ptr,
2262        col_indices_batch_stride,
2263        col_indices_stride,
2264        # col_indices epilogue
2265        # input prologue
2266        input_ptr,
2267        input_batch_stride,
2268        input_tiled_row_stride,
2269        input_tiled_col_stride,
2270        input_row_block_stride,
2271        input_col_block_stride,
2272        # input epilogue
2273        # dense prologue
2274        dense_ptr,
2275        dense_batch_stride,
2276        dense_tiled_row_stride,
2277        dense_tiled_col_stride,
2278        dense_row_block_stride,
2279        dense_col_block_stride,
2280        # dense epilogue
2281        # output prologue
2282        output_ptr,
2283        output_batch_stride,
2284        output_tiled_row_stride,
2285        output_tiled_col_stride,
2286        output_row_block_stride,
2287        output_col_block_stride,
2288        # output epilogue
2289        beta,
2290        alpha,
2291        beta_is_one: tl.constexpr,
2292        beta_is_nonzero: tl.constexpr,
2293        alpha_is_one: tl.constexpr,
2294        BLOCKSIZE_ROW: tl.constexpr,
2295        BLOCKSIZE_COL: tl.constexpr,
2296        BLOCKSIZE_INNER: tl.constexpr,
2297        acc_dtype: tl.constexpr,
2298        allow_tf32: tl.constexpr,
2299        GROUP_SIZE_ROW: tl.constexpr,
2300        SPLIT_N: tl.constexpr,
2301    ):
2302        batch_pid = tl.program_id(axis=2)
2303        row_block_pid = tl.program_id(axis=0)
2304        col_block_pid = tl.program_id(axis=1)
2305        n_block_rows = tl.num_programs(axis=0)
2306        n_block_cols = tl.num_programs(axis=1)
2307
2308        row_block_pid, col_block_pid = tl.swizzle2d(
2309            row_block_pid, col_block_pid, n_block_rows, n_block_cols, GROUP_SIZE_ROW
2310        )
2311
2312        crow_indices_offset_ptr = (
2313            crow_indices_ptr
2314            + crow_indices_batch_stride * batch_pid
2315            + crow_indices_stride * row_block_pid
2316        )
2317        nnz_offset = tl.load(crow_indices_offset_ptr)
2318        nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)
2319
2320        # Compute nnz for the row with number row_block_pid.
2321        row_nnz = nnz_offset_next - nnz_offset
2322
2323        row_block_arange = tl.arange(0, BLOCKSIZE_ROW)
2324        inner_block_arange = tl.arange(0, BLOCKSIZE_INNER)
2325        col_block_arange = tl.arange(0, BLOCKSIZE_COL)
2326
2327        if beta_is_nonzero:
2328            # Pointers are set to exact write-to locations
2329            input_ptrs = (
2330                input_ptr
2331                + input_batch_stride * batch_pid
2332                + input_tiled_row_stride * row_block_pid
2333                + input_tiled_col_stride * col_block_pid
2334                + input_row_block_stride * row_block_arange[:, None]
2335                + input_col_block_stride * col_block_arange[None, :]
2336            )
2337
2338        # Pointers are set to the first block of the current row.
2339        values_block_ptrs = (
2340            values_ptr
2341            + values_batch_stride * batch_pid
2342            + values_nnz_stride * nnz_offset
2343            + values_row_block_stride * row_block_arange[:, None]
2344            + values_col_block_stride * inner_block_arange[None, :]
2345        )
2346
2347        # NOTE: dense pointers are advanced along all dimensions except the tiled
2348        # row one, which is advanced in the loop according to values in col_indices.
2349        dense_block_ptrs = (
2350            dense_ptr
2351            + dense_batch_stride * batch_pid
2352            + dense_tiled_col_stride * col_block_pid
2353            + dense_row_block_stride * inner_block_arange[:, None]
2354            + dense_col_block_stride * col_block_arange[None, :]
2355        )
2356
2357        # Pointers are set to exact write-to locations
2358        output_ptrs = (
2359            output_ptr
2360            + output_batch_stride * batch_pid
2361            + output_tiled_row_stride * row_block_pid
2362            + output_tiled_col_stride * col_block_pid
2363            + output_row_block_stride * row_block_arange[:, None]
2364            + output_col_block_stride * col_block_arange[None, :]
2365        )
2366
2367        # Set pointer to the first nonzero element in the current row
2368        col_index_nnz_ptr = (
2369            col_indices_ptr
2370            + col_indices_batch_stride * batch_pid
2371            + col_indices_stride * nnz_offset
2372        )
2373
2374        # alpha is never 0
2375        if beta_is_nonzero:
2376            output_acc_block = tl.load(input_ptrs).to(acc_dtype)  # type: ignore[possibly-undefined]
2377            if not (beta_is_one and alpha_is_one):
2378                beta_alpha = beta / alpha
2379                output_acc_block *= beta_alpha
2380        else:
2381            output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype)
2382
2383        for _ in range(row_nnz):
2384            values_block = tl.load(values_block_ptrs)
2385
2386            # Find which row of dense needs to be loaded
2387            # for multiplication with values_block.
2388            dense_row_idx = tl.load(col_index_nnz_ptr)
2389            dense_block = tl.load(
2390                dense_block_ptrs + dense_tiled_row_stride * dense_row_idx
2391            )
2392
2393            # do block mm
2394            output_acc_block += tl.dot(
2395                values_block, dense_block, allow_tf32=allow_tf32, out_dtype=acc_dtype
2396            )
2397
2398            # move val/col_index ptrs to the next block in the row
2399            values_block_ptrs += values_nnz_stride
2400            col_index_nnz_ptr += col_indices_stride
2401
2402        if not alpha_is_one:
2403            output_acc_block *= alpha
2404
2405        # write back the result
2406        tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty))
2407
2408else:
2409    bsr_softmax = None  # type: ignore[assignment]
2410    bsr_dense_mm = None  # type: ignore[assignment]
2411    sampled_addmm = None  # type: ignore[assignment]
2412    _scaled_dot_product_attention = None  # type: ignore[assignment]
2413    _scatter_mm2 = None  # type: ignore[assignment]
2414    _scatter_mm6 = None  # type: ignore[assignment]
2415    _bsr_strided_addmm_kernel = None  # type: ignore[assignment]
2416