# mypy: allow-untyped-defs
import threading


__all__ = ["LinearBlockSparsePattern"]


def _is_valid_linear_block_sparse_pattern(row_block_size, col_block_size):
    """Return True if the block shape is one of the accepted ones: 1x4 or 8x1."""
    return (row_block_size == 1 and col_block_size == 4) or (
        row_block_size == 8 and col_block_size == 1
    )


# This is a stop-gap measure, as the current flow does not allow a
# module-specific block sparse pattern.
# In fact, there is no way to convey the sparse pattern via the module config
# of the quantization flow, so a global context is used to convey the
# sparsity pattern.
# Once the flow supports it, this should be removed.
class LinearBlockSparsePattern:
    """Context manager that temporarily overrides the global block sparse pattern.

    The active pattern is stored in the class attributes ``row_block_size`` and
    ``col_block_size`` and is read back through :meth:`block_size`.

    Constructing an instance acquires the class-wide re-entrant lock and
    installs the requested pattern; ``__exit__`` restores the pattern that was
    active at construction time and releases the lock.  Because the lock is
    taken in ``__init__``, an instance that is created but never used in a
    ``with`` statement keeps the lock held by the constructing thread.

    Contexts may be nested within one thread; each instance remembers the
    pattern it replaced, so unwinding restores the patterns in reverse order.

    Args:
        row_block_size: Number of rows in a block (default 1).
        col_block_size: Number of columns in a block (default 4).

    Raises:
        AssertionError: If the (row, col) pair is not (1, 4) or (8, 1).
    """

    rlock = threading.RLock()
    # Currently active pattern (defaults to 1x4).
    row_block_size = 1
    col_block_size = 4
    # Pattern replaced by the most recently constructed instance.  Kept for
    # backward compatibility; restoration uses the per-instance snapshot.
    prev_row_block_size = 1
    prev_col_block_size = 4

    def __init__(self, row_block_size=1, col_block_size=4):
        # Raise explicitly instead of using a bare ``assert`` so the check is
        # not stripped when Python runs with -O; the exception type is kept.
        if not _is_valid_linear_block_sparse_pattern(row_block_size, col_block_size):
            raise AssertionError(
                "Unsupported block sparse pattern "
                f"({row_block_size}, {col_block_size}); expected (1, 4) or (8, 1)"
            )
        cls = LinearBlockSparsePattern
        cls.rlock.acquire()
        # Per-instance snapshot: a nested context overwrites the class-level
        # ``prev_*`` values, so they cannot be relied upon when unwinding.
        self._saved_block_size = (cls.row_block_size, cls.col_block_size)
        cls.prev_row_block_size, cls.prev_col_block_size = self._saved_block_size
        cls.row_block_size = row_block_size
        cls.col_block_size = col_block_size

    def __enter__(self):
        # All set-up already happened in __init__; nothing is bound by ``as``.
        pass

    def __exit__(self, exc_type, exc_value, backtrace):
        cls = LinearBlockSparsePattern
        try:
            # Restore the pattern this instance replaced.
            cls.row_block_size, cls.col_block_size = self._saved_block_size
        finally:
            # Always pair the acquire() done in __init__.
            cls.rlock.release()

    @staticmethod
    def block_size():
        """Return the currently active pattern as ``(row_block_size, col_block_size)``."""
        return (
            LinearBlockSparsePattern.row_block_size,
            LinearBlockSparsePattern.col_block_size,
        )