# Owner(s): ["module: unknown"]

import copy
import io
import logging
from itertools import product

import numpy as np

import torch
import torch.ao.quantization as tq
from torch import nn
from torch.ao.pruning.sparsifier.utils import fqn_to_module
from torch.testing._internal.common_quantized import (
    override_cpu_allocator_for_qnnpack,
    override_qengines,
    qengine_is_fbgemm,
    qengine_is_onednn,
    qengine_is_qnnpack,
    qengine_is_x86,
)
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase

# TODO: Once more test files are created, move the contents to an ao folder.

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)


class TestQuantizedSparseKernels(TestCase):
    @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons")
    @override_qengines
    def test_sparse_qlinear(self):
        batch_size = 12
        input_channels = 16
        output_channels = 4
        decimal_val = 4
        row_block_size = 1
        col_block_size = 4

        # The x86 implementation of the sparse ops in qnnpack only supports
        # the 1x4 block pattern, while the ARM kernels support both 1x4 and
        # 8x1. The x86 implementation exists only to exercise the integration
        # path, hence this special case. We plan to add 8x1 support as well
        # so that the test no longer needs it, but that work is currently
        # deprioritized in favor of higher-priority items.
        if qengine_is_qnnpack() and not (row_block_size == 1 and col_block_size == 4):
            return
        # ONEDNN and X86 do not support this yet
        if qengine_is_onednn() or qengine_is_x86():
            return

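        # The dense quantized linear ops below serve as the reference
        # implementation; the sparse.* ops are the kernels under test. Unlike
        # the dense prepack, the sparse prepack also takes the block pattern
        # (row_block_size x col_block_size) of the weight.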
        dense_prepack = torch.ops.quantized.linear_prepack
        dense_qlinear = torch.ops.quantized.linear
        dense_qlinear_dynamic = torch.ops.quantized.linear_dynamic

        sparse_prepack = torch.ops.sparse.qlinear_prepack
        sparse_qlinear = torch.ops.sparse.qlinear
        sparse_qlinear_dynamic = torch.ops.sparse.qlinear_dynamic

        X_scale = 0.2
        X_zp = 2
        X_fp32 = torch.randn(batch_size, input_channels, dtype=torch.float32)
        float_bias = torch.randn(output_channels, dtype=torch.float32)

        W_scales = torch.rand(output_channels, dtype=torch.float32)
        W_zps = torch.zeros(output_channels, dtype=torch.int32)
        W_fp32 = torch.randn(output_channels, input_channels, dtype=torch.float32)

        with override_cpu_allocator_for_qnnpack(qengine_is_qnnpack()):
            X_q = torch.quantize_per_tensor(
                X_fp32, scale=X_scale, zero_point=X_zp, dtype=torch.quint8
            )

            for use_channelwise, dynamic_mode in product([True, False], [True, False]):
                if qengine_is_fbgemm() and dynamic_mode:
                    logging.info("dynamic sparse qlinear is only available in qnnpack")
                    continue
                if qengine_is_qnnpack() and not dynamic_mode:
                    logging.info("static sparse qlinear is only available in fbgemm")
                    continue
                if use_channelwise:
                    W_q = torch.quantize_per_channel(
                        W_fp32,
                        scales=W_scales,
                        zero_points=W_zps,
                        axis=0,
                        dtype=torch.qint8,
                    )
                else:
                    W_q = torch.quantize_per_tensor(
                        W_fp32,
                        scale=W_scales[0],
                        zero_point=W_zps[0],
                        dtype=torch.qint8,
                    )

                Y_scale = 1.1234
                Y_zp = 5
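                # The sparse prepack packs the quantized weight for the
                # block-sparse kernel using the requested block pattern; the
                # dense prepack produces the packed weight for the reference op.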
                W_prepack_dense = dense_prepack(W_q, float_bias)
                W_prepack_sparse = sparse_prepack(
                    W_q, float_bias, row_block_size, col_block_size
                )

                if dynamic_mode:
                    Y = sparse_qlinear_dynamic(X_fp32, W_prepack_sparse)
                    Y_ref = dense_qlinear_dynamic(X_fp32, W_prepack_dense)

                    np.testing.assert_array_almost_equal(
                        Y_ref.numpy(), Y.numpy(), decimal=decimal_val
                    )
                else:
                    Y_q = sparse_qlinear(X_q, W_prepack_sparse, Y_scale, Y_zp)
                    Y_q_ref = dense_qlinear(X_q, W_prepack_dense, Y_scale, Y_zp)

                    np.testing.assert_array_almost_equal(
                        Y_q_ref.int_repr().numpy(),
                        Y_q.int_repr().numpy(),
                        decimal=decimal_val,
                    )


def _sparse_layer_test_helper(
    model_class,
    sparse_mapping,
    ref_mapping,
    qconfig_dict,
    fqn_to_check,
    test_class,
    test_scripting,
):
    # SET UP TEST PARAMETERS, INPUTS AND WEIGHTS
    # ------------------------------------------
    batch_size = 12
    input_channels = 4
    output_channels = 7
    model = model_class(input_channels, output_channels)

    # For the sparse kernels the weight zero point is kept at 0
    # (the activation zero point need not be)
    X_scale = 0.2
    X_zp = 2
    W_scale = 1e-2
    W_zp = 0

    X_fp32 = torch.randn(batch_size, input_channels, dtype=torch.float32)
    float_bias = torch.randn(output_channels, dtype=torch.float32)

    # generate a weight which we'll insert into the model
    W_fp32 = torch.randn(output_channels, input_channels, dtype=torch.float32)
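    # zero out roughly half of the entries with a random 0/1 mask so that the
    # sparse kernels actually see a sparse weight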
    mask = torch.randint(0, 2, W_fp32.shape)
    W_fp32 *= mask
    with override_cpu_allocator_for_qnnpack(qengine_is_qnnpack()):
        X_q = torch.quantize_per_tensor(
            X_fp32, scale=X_scale, zero_point=X_zp, dtype=torch.quint8
        )
        X_fp32 = X_q.dequantize()

        W_q = torch.quantize_per_tensor(W_fp32, W_scale, W_zp, torch.qint8)

        # PREPARE MODELS FOR QUANTIZATION
        # -------------------------------
        model.linear.weight = nn.Parameter(W_q.dequantize())
        model.eval()

        # Add `sparse_params` to the model. The test for correct
        # sparse_param addition is in the sparsifier tests
        model.linear.sparse_params = {"sparse_block_shape": (1, 4)}

        # generate the dense-quantized (reference) and sparse-quantized model versions
        qmodel = copy.deepcopy(model)
        sqmodel = copy.deepcopy(model)

        # apply qconfigs to both versions and prepare them for quantization
        tq.propagate_qconfig_(qmodel, qconfig_dict)
        tq.propagate_qconfig_(sqmodel, qconfig_dict)

        tq.prepare(qmodel, inplace=True)
        tq.prepare(sqmodel, inplace=True)

        # calibrate
        with torch.no_grad():
            qmodel(X_fp32)
            sqmodel(X_fp32)
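        # prepare() inserted observers into both models. On the static (fbgemm)
        # path these forward passes let the observers record activation ranges
        # that convert() uses to compute qparams; on the dynamic (qnnpack) path
        # the activation observer is a placeholder, so calibration is
        # effectively a no-op.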

        # ACTUAL TESTING BEGINS HERE
        # --------------------------

        # Make sure the quantization parameters are computed the same way
        qparams = qmodel.linear.qconfig.weight().calculate_qparams()
        sqparams = sqmodel.linear.qconfig.weight().calculate_qparams()
        test_class.assertEqual(qparams, sqparams)

        sqmodule_to_check = fqn_to_module(sqmodel, fqn_to_check)
        sqmodule_start_class = sqmodule_to_check.__class__
        sqmodule_expected_converted_class = sparse_mapping[sqmodule_start_class]

        qmodule_to_check = fqn_to_module(qmodel, fqn_to_check)
        qmodule_start_class = qmodule_to_check.__class__
        qmodule_expected_converted_class = ref_mapping[qmodule_start_class]

        # determine whether dynamic quantization is being performed, since the
        # input dtype at inference time differs between the static and dynamic paths
        is_dynamic = isinstance(
            qmodule_to_check.activation_post_process, tq.PlaceholderObserver
        )

        tq.convert(sqmodel, inplace=True, mapping=sparse_mapping)
        tq.convert(qmodel, inplace=True, mapping=ref_mapping)

        # re-fetch the modules: convert() swaps them out in place, so the
        # references obtained above still point at the pre-convert modules
        sqmodule_to_check = fqn_to_module(sqmodel, fqn_to_check)
        qmodule_to_check = fqn_to_module(qmodel, fqn_to_check)

        # check that the modules were converted as expected
        assert isinstance(
            sqmodule_to_check, sqmodule_expected_converted_class
        ), "Convert failed"
        assert isinstance(
            qmodule_to_check, qmodule_expected_converted_class
        ), "Mapping failed"

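        # the packed params of the converted sparse module expose the block
        # shape after the weight and bias; it should match the (1, 4)
        # sparse_block_shape set via sparse_params above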
        row_block_size, col_block_size = sqmodel.linear._packed_params._weight_bias()[
            2:
        ]
        assert row_block_size == 1 and col_block_size == 4

        # only run during serialization/deserialization tests; makes sure that
        # scripting, saving, and loading do not malform the sqmodel
        if test_scripting:
            scripted_sqmodel = torch.jit.script(sqmodel)
            scripted_sqmodel.eval()
            buffer = io.BytesIO()
            torch.jit.save(scripted_sqmodel, buffer)
            buffer.seek(0)
            sqmodel = torch.jit.load(buffer)

        # use correct input dtype
        if is_dynamic:
            Y_ref = qmodel(X_fp32)
            Y_hat = sqmodel(X_fp32)
            test_class.assertEqual(Y_ref, Y_hat)
        else:
            Y_ref = qmodel(X_q)
            Y_hat = sqmodel(X_q)
            test_class.assertEqual(Y_ref.dequantize(), Y_hat.dequantize())


class SparseQuantizedModel(nn.Module):
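    """Minimal single-Linear model shared by the sparse quantized layer tests."""
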
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.linear = nn.Linear(in_channels, out_channels)

    def forward(self, x):
        return self.linear(x)


class TestQuantizedSparseLayers(TestCase):
    @override_qengines
    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
    def test_sparse_qlinear(self):
        # Note: At the moment, for sparse kernels, fbgemm supports only
        # statically quantized sparse linear while qnnpack supports only
        # dynamically quantized sparse linear. Hence the two engines exercise
        # different flows here: fbgemm tests the static flow and qnnpack the
        # dynamic flow. The two should be unified later on and the tests
        # adjusted accordingly.
        model_class = SparseQuantizedModel
        fqn_to_check = "linear"
        if qengine_is_fbgemm():
            sparse_mapping = tq.get_default_static_sparse_quant_module_mappings()
            ref_mapping = tq.get_default_static_quant_module_mappings()
            qconfig_dict = {nn.Linear: tq.get_default_qconfig("fbgemm")}
        elif qengine_is_qnnpack():
            sparse_mapping = tq.get_default_dynamic_sparse_quant_module_mappings()
            ref_mapping = tq.get_default_dynamic_quant_module_mappings()
            qconfig_dict = {nn.Linear: tq.qconfig.default_dynamic_qconfig}
        else:
            return

        _sparse_layer_test_helper(
            model_class=model_class,
            sparse_mapping=sparse_mapping,
            ref_mapping=ref_mapping,
            qconfig_dict=qconfig_dict,
            fqn_to_check=fqn_to_check,
            test_class=self,
            test_scripting=False,
        )

    @override_qengines
    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
    def test_sparse_qlinear_serdes(self):
        # See the note in test_sparse_qlinear above: at the moment fbgemm
        # supports only statically quantized sparse linear and qnnpack only
        # dynamically quantized sparse linear, so the two engines exercise
        # different flows here as well.
        model_class = SparseQuantizedModel
        fqn_to_check = "linear"
        if qengine_is_fbgemm():
            sparse_mapping = tq.get_default_static_sparse_quant_module_mappings()
            ref_mapping = tq.get_default_static_quant_module_mappings()
            qconfig_dict = {nn.Linear: tq.get_default_qconfig("fbgemm")}
        elif qengine_is_qnnpack():
            sparse_mapping = tq.get_default_dynamic_sparse_quant_module_mappings()
            ref_mapping = tq.get_default_dynamic_quant_module_mappings()
            qconfig_dict = {nn.Linear: tq.qconfig.default_dynamic_qconfig}
        else:
            return

        _sparse_layer_test_helper(
            model_class=model_class,
            sparse_mapping=sparse_mapping,
            ref_mapping=ref_mapping,
            qconfig_dict=qconfig_dict,
            fqn_to_check=fqn_to_check,
            test_class=self,
            test_scripting=True,
        )


if __name__ == "__main__":
    run_tests()