xref: /aosp_15_r20/external/pytorch/torch/ao/nn/sparse/quantized/utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import threading
3
4
5__all__ = ["LinearBlockSparsePattern"]
6
7
def _is_valid_linear_block_sparse_pattern(row_block_size, col_block_size):
    """Return whether ``(row_block_size, col_block_size)`` is a supported shape.

    Exactly two linear block-sparse shapes are accepted: 1x4 and 8x1.
    """
    supported_shapes = ((1, 4), (8, 1))
    return (row_block_size, col_block_size) in supported_shapes
12
13
# This is a stop-gap measure, as the current flow does not allow a
# module-specific block sparse pattern.
# In fact, there is no way to convey the sparse pattern via the module config
# of the quantization flow, so a global context is used to convey the
# sparsity pattern.
# Once the flow supports it, this should be removed.
class LinearBlockSparsePattern:
    """Class-wide, lock-guarded holder for the current linear block-sparse pattern.

    Constructing an instance validates the requested block shape, acquires the
    class-level re-entrant lock and installs the shape as the current pattern.
    Exiting the ``with`` block restores the shape that was current when the
    instance was constructed and releases the lock::

        with LinearBlockSparsePattern(8, 1):
            ...  # LinearBlockSparsePattern.block_size() == (8, 1)

    Nested contexts unwind correctly because each instance remembers its own
    previous shape.

    Note: the lock is acquired in ``__init__`` (not ``__enter__``). An
    instance created outside a ``with`` statement therefore keeps the lock
    held and its pattern installed until ``__exit__`` is called explicitly.
    """

    # Re-entrant so that the same thread may nest ``with`` blocks.
    rlock = threading.RLock()
    # Currently active block shape (default 1x4).
    row_block_size = 1
    col_block_size = 4
    # Shape that was active before the most recent construction. Kept for
    # backward compatibility only; restoration uses per-instance state so
    # that nested contexts do not clobber each other's saved shape.
    prev_row_block_size = 1
    prev_col_block_size = 4

    def __init__(self, row_block_size=1, col_block_size=4):
        """Validate and install ``(row_block_size, col_block_size)``.

        Args:
            row_block_size: rows per sparse block.
            col_block_size: columns per sparse block.

        Raises:
            AssertionError: if the shape is not one of the supported patterns
                (1x4 or 8x1).
        """
        assert _is_valid_linear_block_sparse_pattern(
            row_block_size, col_block_size
        ), f"unsupported linear block sparse pattern: ({row_block_size}, {col_block_size})"
        cls = LinearBlockSparsePattern
        cls.rlock.acquire()
        # Save the shape to restore on *this* instance. A single class-level
        # slot would be overwritten by a nested context, leaving the outer
        # context unable to restore the original pattern.
        self._saved_row_block_size = cls.row_block_size
        self._saved_col_block_size = cls.col_block_size
        cls.prev_row_block_size = cls.row_block_size
        cls.prev_col_block_size = cls.col_block_size
        cls.row_block_size = row_block_size
        cls.col_block_size = col_block_size

    def __enter__(self):
        # The pattern was already installed (and the lock taken) in __init__.
        return self

    def __exit__(self, exc_type, exc_value, backtrace):
        """Restore the pattern saved at construction and release the lock.

        Returns ``None``, so any exception raised inside the block propagates.
        """
        cls = LinearBlockSparsePattern
        cls.row_block_size = self._saved_row_block_size
        cls.col_block_size = self._saved_col_block_size
        cls.rlock.release()

    @staticmethod
    def block_size():
        """Return the currently active ``(row_block_size, col_block_size)``."""
        return (
            LinearBlockSparsePattern.row_block_size,
            LinearBlockSparsePattern.col_block_size,
        )
57