1# Owner(s): ["module: unknown"] 2 3import copy 4import io 5import logging 6from itertools import product 7 8import numpy as np 9 10import torch 11import torch.ao.quantization as tq 12from torch import nn 13from torch.ao.pruning.sparsifier.utils import fqn_to_module 14from torch.testing._internal.common_quantized import ( 15 override_cpu_allocator_for_qnnpack, 16 override_qengines, 17 qengine_is_fbgemm, 18 qengine_is_onednn, 19 qengine_is_qnnpack, 20 qengine_is_x86, 21) 22from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase 23 24 25# TODO: Once more test files are created, move the contents to a ao folder. 26 27logging.basicConfig( 28 format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO 29) 30 31 32class TestQuantizedSparseKernels(TestCase): 33 @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons") 34 @override_qengines 35 def test_sparse_qlinear(self): 36 batch_size = 12 37 input_channels = 16 38 output_channels = 4 39 decimal_val = 4 40 row_block_size = 1 41 col_block_size = 4 42 43 # X86 implementation of sparse ops in qnnpack only support 44 # block pattern 1x4. 45 # arm kernels have support for both 1x4 and 8x1. 46 # This distinction is only because x86 implementations exist 47 # only to enable testing of integration path. 48 # We do plan to add 8x1 as well so that testing does not have to 49 # special case like this. At the moment it is deprioritized due 50 # to other higher priority works. 51 if qengine_is_qnnpack() and not (row_block_size == 1 and col_block_size == 4): 52 return 53 # ONEDNN and X86 do not support this yet 54 if qengine_is_onednn() or qengine_is_x86(): 55 return 56 57 dense_prepack = torch.ops.quantized.linear_prepack 58 dense_qlinear = torch.ops.quantized.linear 59 dense_qlinear_dynamic = torch.ops.quantized.linear_dynamic 60 61 sparse_prepack = torch.ops.sparse.qlinear_prepack 62 sparse_qlinear = torch.ops.sparse.qlinear 63 sparse_qlinear_dynamic = torch.ops.sparse.qlinear_dynamic 64 65 X_scale = 0.2 66 X_zp = 2 67 X_fp32 = torch.randn(batch_size, input_channels, dtype=torch.float32) 68 float_bias = torch.randn(output_channels, dtype=torch.float32) 69 70 W_scales = torch.rand(output_channels, dtype=torch.float32) 71 W_zps = torch.zeros(output_channels, dtype=torch.int32) 72 W_fp32 = torch.randn(output_channels, input_channels, dtype=torch.float32) 73 74 with override_cpu_allocator_for_qnnpack(qengine_is_qnnpack()): 75 X_q = torch.quantize_per_tensor( 76 X_fp32, scale=X_scale, zero_point=X_zp, dtype=torch.quint8 77 ) 78 79 for use_channelwise, dynamic_mode in product([True, False], [True, False]): 80 if qengine_is_fbgemm() and dynamic_mode: 81 logging.info("dynamic sparse qlinear is only available in qnnpack") 82 continue 83 if qengine_is_qnnpack() and not dynamic_mode: 84 logging.info("static sparse qlinear is only available in fbgemm") 85 continue 86 if use_channelwise: 87 W_q = torch.quantize_per_channel( 88 W_fp32, 89 scales=W_scales, 90 zero_points=W_zps, 91 axis=0, 92 dtype=torch.qint8, 93 ) 94 else: 95 W_q = torch.quantize_per_tensor( 96 W_fp32, 97 scale=W_scales[0], 98 zero_point=W_zps[0], 99 dtype=torch.qint8, 100 ) 101 102 Y_scale = 1.1234 103 Y_zp = 5 104 W_prepack_dense = dense_prepack(W_q, float_bias) 105 W_prepack_sparse = sparse_prepack( 106 W_q, float_bias, row_block_size, col_block_size 107 ) 108 109 if dynamic_mode: 110 Y = sparse_qlinear_dynamic(X_fp32, W_prepack_sparse) 111 Y_ref = dense_qlinear_dynamic(X_fp32, W_prepack_dense) 112 113 np.testing.assert_array_almost_equal( 114 Y_ref.numpy(), Y.numpy(), decimal=decimal_val 115 ) 116 else: 117 Y_q = sparse_qlinear(X_q, W_prepack_sparse, Y_scale, Y_zp) 118 Y_q_ref = dense_qlinear(X_q, W_prepack_dense, Y_scale, Y_zp) 119 120 np.testing.assert_array_almost_equal( 121 Y_q_ref.int_repr().numpy(), 122 Y_q.int_repr().numpy(), 123 decimal=decimal_val, 124 ) 125 126 127def _sparse_layer_test_helper( 128 model_class, 129 sparse_mapping, 130 ref_mapping, 131 qconfig_dict, 132 fqn_to_check, 133 test_class, 134 test_scripting, 135): 136 # SET UP TEST PARAMETERS, INPUTS AND WEIGHTS 137 # ------------------------------------------ 138 batch_size = 12 139 input_channels = 4 140 output_channels = 7 141 model = model_class(input_channels, output_channels) 142 143 # For sparse kernels both the activation and weight ZP = 0 144 X_scale = 0.2 145 X_zp = 2 146 W_scale = 1e-2 147 W_zp = 0 148 149 X_fp32 = torch.randn(batch_size, input_channels, dtype=torch.float32) 150 float_bias = torch.randn(output_channels, dtype=torch.float32) 151 152 # generate a weight which we'll insert into the model 153 W_fp32 = torch.randn(output_channels, input_channels, dtype=torch.float32) 154 mask = torch.randint(0, 2, W_fp32.shape) 155 W_fp32 *= mask 156 with override_cpu_allocator_for_qnnpack(qengine_is_qnnpack()): 157 X_q = torch.quantize_per_tensor( 158 X_fp32, scale=X_scale, zero_point=X_zp, dtype=torch.quint8 159 ) 160 X_fp32 = X_q.dequantize() 161 162 W_q = torch.quantize_per_tensor(W_fp32, W_scale, W_zp, torch.qint8) 163 164 # PREPARE MODELS FOR QUANTIZATION 165 # ------------------------------- 166 model.linear.weight = nn.Parameter(W_q.dequantize()) 167 model.eval() 168 169 # Add `sparse_params` to the model. The test for correct 170 # sparse_param addition is in the sparsifier tests 171 model.linear.sparse_params = {"sparse_block_shape": (1, 4)} 172 173 # generate model versions 174 qmodel = copy.deepcopy(model) 175 sqmodel = copy.deepcopy(model) 176 177 # generate model versions and apply qconfigs 178 tq.propagate_qconfig_(qmodel, qconfig_dict) 179 tq.propagate_qconfig_(sqmodel, qconfig_dict) 180 181 tq.prepare(qmodel, inplace=True) 182 tq.prepare(sqmodel, inplace=True) 183 184 # calibrate 185 with torch.no_grad(): 186 qmodel(X_fp32) 187 sqmodel(X_fp32) 188 189 # ACTUAL TESTING BEGINS HERE 190 # -------------------------- 191 192 # Make sure the quantization parameters are computed the same way 193 qparams = qmodel.linear.qconfig.weight().calculate_qparams() 194 sqparams = sqmodel.linear.qconfig.weight().calculate_qparams() 195 test_class.assertEqual(qparams, sqparams) 196 197 sqmodule_to_check = fqn_to_module(sqmodel, fqn_to_check) 198 sqmodule_start_class = sqmodule_to_check.__class__ 199 sqmodule_expected_converted_class = sparse_mapping[sqmodule_start_class] 200 201 qmodule_to_check = fqn_to_module(qmodel, fqn_to_check) 202 qmodule_start_class = qmodule_to_check.__class__ 203 qmodule_expected_converted_class = ref_mapping[qmodule_start_class] 204 205 # need to determine whether dynamic quantization is being performed since 206 # input dtype will be different at the end 207 is_dynamic = isinstance( 208 qmodule_to_check.activation_post_process, tq.PlaceholderObserver 209 ) 210 211 tq.convert(sqmodel, inplace=True, mapping=sparse_mapping) 212 tq.convert(qmodel, inplace=True, mapping=ref_mapping) 213 214 # this code is a duplicate of above since the references do not 215 # update to the post-convert modules 216 sqmodule_to_check = fqn_to_module(sqmodel, fqn_to_check) 217 qmodule_to_check = fqn_to_module(qmodel, fqn_to_check) 218 219 # check that the modules were converted as expected 220 assert isinstance( 221 sqmodule_to_check, sqmodule_expected_converted_class 222 ), "Convert failed" 223 assert isinstance( 224 qmodule_to_check, qmodule_expected_converted_class 225 ), "Mapping failed" 226 227 row_block_size, col_block_size = sqmodel.linear._packed_params._weight_bias()[ 228 2: 229 ] 230 assert row_block_size == 1 and col_block_size == 4 231 232 # only run during serialization/deserialization tests 233 # makes sure script/save/load doesn't malform the sqmodel 234 if test_scripting: 235 scripted_sqmodel = torch.jit.script(sqmodel) 236 scripted_sqmodel.eval() 237 buffer = io.BytesIO() 238 torch.jit.save(scripted_sqmodel, buffer) 239 buffer.seek(0) 240 sqmodel = torch.jit.load(buffer) 241 242 # use correct input dtype 243 if is_dynamic: 244 Y_ref = qmodel(X_fp32) 245 Y_hat = sqmodel(X_fp32) 246 test_class.assertEqual(Y_ref, Y_hat) 247 else: 248 Y_ref = qmodel(X_q) 249 Y_hat = sqmodel(X_q) 250 test_class.assertEqual(Y_ref.dequantize(), Y_hat.dequantize()) 251 252 253class SparseQuantizedModel(nn.Module): 254 def __init__(self, in_channels, out_channels): 255 super().__init__() 256 self.linear = nn.Linear(in_channels, out_channels) 257 258 def forward(self, x): 259 return self.linear(x) 260 261 262class TestQuantizedSparseLayers(TestCase): 263 @override_qengines 264 @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991") 265 def test_sparse_qlinear(self): 266 # Note: At the moment, for sparse kernels 267 # fbgemm supports only static quantized sparse linear 268 # qnnpack supports only dynamically quantized sparse linear 269 # Hence we have two different tests. 270 # fbgemm tests static flow, qnnpack tests dynamic. 271 # Should be unified later on and tests should be fixed 272 # appropriately. 273 model_class = SparseQuantizedModel 274 fqn_to_check = "linear" 275 if qengine_is_fbgemm(): 276 sparse_mapping = tq.get_default_static_sparse_quant_module_mappings() 277 ref_mapping = tq.get_default_static_quant_module_mappings() 278 qconfig_dict = {nn.Linear: tq.get_default_qconfig("fbgemm")} 279 elif qengine_is_qnnpack(): 280 sparse_mapping = tq.get_default_dynamic_sparse_quant_module_mappings() 281 ref_mapping = tq.get_default_dynamic_quant_module_mappings() 282 qconfig_dict = {nn.Linear: tq.qconfig.default_dynamic_qconfig} 283 else: 284 return 285 286 _sparse_layer_test_helper( 287 model_class=model_class, 288 sparse_mapping=sparse_mapping, 289 ref_mapping=ref_mapping, 290 qconfig_dict=qconfig_dict, 291 fqn_to_check=fqn_to_check, 292 test_class=self, 293 test_scripting=False, 294 ) 295 296 @override_qengines 297 @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991") 298 def test_sparse_qlinear_serdes(self): 299 # Note: At the moment, for sparse kernels 300 # fbgemm supports only static quantized sparse linear 301 # qnnpack supports only dynamically quantized sparse linear 302 # Hence we have two different tests. 303 # fbgemm tests static flow, qnnpack tests dynamic. 304 # Should be unified later on and tests should be fixed 305 # appropriately. 306 model_class = SparseQuantizedModel 307 fqn_to_check = "linear" 308 if qengine_is_fbgemm(): 309 sparse_mapping = tq.get_default_static_sparse_quant_module_mappings() 310 ref_mapping = tq.get_default_static_quant_module_mappings() 311 qconfig_dict = {nn.Linear: tq.get_default_qconfig("fbgemm")} 312 elif qengine_is_qnnpack(): 313 sparse_mapping = tq.get_default_dynamic_sparse_quant_module_mappings() 314 ref_mapping = tq.get_default_dynamic_quant_module_mappings() 315 qconfig_dict = {nn.Linear: tq.qconfig.default_dynamic_qconfig} 316 else: 317 return 318 319 _sparse_layer_test_helper( 320 model_class=model_class, 321 sparse_mapping=sparse_mapping, 322 ref_mapping=ref_mapping, 323 qconfig_dict=qconfig_dict, 324 fqn_to_check=fqn_to_check, 325 test_class=self, 326 test_scripting=True, 327 ) 328 329 330if __name__ == "__main__": 331 run_tests() 332