#
#  Copyright (c) 2024 Apple Inc. All rights reserved.
#  Provided subject to the LICENSE file in the top level directory.
#

import inspect

import unittest

from typing import Tuple

import torch
from executorch.backends.apple.mps.test.test_mps_utils import TestMPS


class TestLinear(TestMPS):
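    """Linear-op tests for the MPS delegate.

    Covers plain nn.Linear in fp16/fp32 and quantized (qc8/qs8) forms, addmm,
    linear fused with relu, and a manual dynamically quantized linear
    (ManualDQLinear) with per-channel or per-channel-group int4/int8 weights.
    """
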
    @unittest.skip("Dynamic shapes not supported in MPS backend")
    def test_fp16_linear(self):
        for use_bias in (True, False):
            for num_batch_dims in range(1, 3):
                self._test_linear(
                    lambda in_size, out_size: torch.nn.Linear(
                        in_size, out_size, bias=use_bias  # noqa
                    ),
                    num_batch_dims=num_batch_dims,
                    uses_bias=use_bias,
                    dtype=torch.float16,
                    atol=5e-2,
                )

    @unittest.skip("Dynamic shapes not supported in MPS backend")
    def test_fp32_linear(self):
        for use_bias in (True, False):
            for num_batch_dims in range(1, 3):
                self._test_linear(
                    lambda in_size, out_size: torch.nn.Linear(
                        in_size, out_size, bias=use_bias  # noqa
                    ),
                    uses_bias=use_bias,
                    num_batch_dims=num_batch_dims,
                )

    @unittest.skip("Dynamic shapes not supported in MPS backend")
    def test_qc8_linear(self):
        for use_bias in (True, False):
            for num_batch_dims in range(1, 3):
                self._test_linear(
                    lambda in_size, out_size: torch.nn.Linear(
                        in_size, out_size, bias=use_bias  # noqa
                    ),
                    uses_bias=use_bias,
                    quant_type="per_channel",
                    num_batch_dims=num_batch_dims,
                )

    @unittest.skip("Dynamic shapes not supported in MPS backend")
    def test_fp32_addmm(self):
        """
        Note that the ConvertToLinear pass requires the weight matrix to be transposed.
        """

        class AddMMModule(torch.nn.Module):
            def __init__(self, in_size, out_size):
                super().__init__()
                self.mat = torch.nn.Parameter(torch.randn(in_size, out_size))
                self.bias = torch.nn.Parameter(torch.randn(1, out_size))

            def forward(self, x):
                return torch.addmm(self.bias, x, self.mat)

        self._test_linear(
            lambda in_size, out_size: AddMMModule(in_size, out_size),
            uses_bias=True,
        )

    @unittest.skip("Dynamic shapes not supported in MPS backend")
    def test_fp32_linear_fused_relu(self):
        class LinearReluModule(torch.nn.Module):
            def __init__(self, in_size, out_size, use_bias):
                super().__init__()
                self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias)

            def forward(self, x):
                return torch.nn.functional.relu(self.linear(x))

        for use_bias in (True, False):
            for num_batch_dims in range(1, 3):
                self._test_linear(
                    lambda in_size, out_size: LinearReluModule(
                        in_size,
                        out_size,
                        use_bias,  # noqa
                    ),
                    uses_bias=use_bias,
                    num_batch_dims=num_batch_dims,
                )

    @unittest.skip("Dynamic shapes not supported in MPS backend")
    def test_qs8_linear_fused_relu(self):
        class LinearReluModule(torch.nn.Module):
            def __init__(self, in_size, out_size, use_bias):
                super().__init__()
                self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias)

            def forward(self, x):
                return torch.nn.functional.relu(self.linear(x))

        for use_bias in (True, False):
            for num_batch_dims in range(1, 3):
                self._test_linear(
                    lambda in_size, out_size: LinearReluModule(
                        in_size,
                        out_size,
                        use_bias,  # noqa
                    ),
                    num_batch_dims=num_batch_dims,
                    uses_bias=use_bias,
                    quant_type="per_tensor",
                )

    @unittest.skip("Dynamic shapes not supported in MPS backend")
    def test_qs8_linear(self):
        for use_bias in (True, False):
            for num_batch_dims in range(1, 3):
                self._test_linear(
                    lambda in_size, out_size: torch.nn.Linear(
                        in_size, out_size, bias=use_bias  # noqa
                    ),
                    uses_bias=use_bias,
                    num_batch_dims=num_batch_dims,
                    quant_type="per_tensor",
                )

    @unittest.skip(
        "quantized_decomposed_dequantize_per_channel_default is not supported by the MPS delegate"
    )
    def test_qd8_fp32_per_token_weight_per_channel_int8(self):
        self._run_manual_dqlinear_tests(8, torch.float)

    @unittest.skip(
        "quantized_decomposed_dequantize_per_channel_default is not supported by the MPS delegate"
    )
    def test_qd8_fp32_per_token_weight_per_channel_int4(self):
        self._run_manual_dqlinear_tests(4, torch.float)

    def test_qd8_fp32_per_token_weight_per_channel_group_int4(self):
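        """Exercises int4 per-channel-group weight quantization (group size 64)
        together with fp32 per-token dynamic activation quantization."""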
        M_sizes = [1]
        K_sizes = [64]
        bl_sizes = [64]
        N_sizes = [32]

        for use_bias in [True, False]:
            for i, _ in enumerate(M_sizes):
                M = int(M_sizes[i])
                K = int(K_sizes[i])
                N = int(N_sizes[i])
                bl = int(bl_sizes[i])
                mod = self.ManualDQLinear(
                    input_channels=K,
                    output_channels=N,
                    weight_n_bit=4,
                    dtype=torch.float,
                    group_size=bl,
                    force_groupwise_quant=True,
                    use_bias=use_bias,
                )

                inputs = (torch.randn(1, M, K),)
                self._test_manual_dq_linear(
                    mod,
                    inputs,
                    weight_groupwise=True,
                    use_bias=use_bias,
                )

    @unittest.skip("Need to fix the dq_per_channel_group output dtype")
    def _test_qd8_fp16_per_token_weight_per_channel_group_int4(self):
        M_sizes = [1, 2, 17, 31]
        K_sizes = [8, 32, 64, 128]
        bl_sizes = [8, 16, 16, 32]
        N_sizes = [2, 17, 92, 128]

        for use_bias in [True, False]:
            for i, _ in enumerate(M_sizes):
                M = int(M_sizes[i])
                K = int(K_sizes[i])
                N = int(N_sizes[i])
                bl = int(bl_sizes[i])
                mod = self.ManualDQLinear(
                    input_channels=K,
                    output_channels=N,
                    weight_n_bit=4,
                    dtype=torch.float16,
                    group_size=bl,
                    force_groupwise_quant=True,
                    use_bias=use_bias,
                )

                inputs = (torch.randn(1, M, K, dtype=torch.float16),)
                self._test_manual_dq_linear(
                    mod,
                    inputs,
                    weight_groupwise=True,
                    use_bias=use_bias,
                    atol=0.1,
                    rtol=0.1,
                )

    def _test_linear(
        self,
        make_module,
        uses_bias,
        num_batch_dims=1,
        quant_type=None,
        dtype: torch.dtype = torch.float,
        atol=1e-03,
    ):
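        """Builds make_module(input_size, output_size) for a small sweep of
        shapes and lowers it through lower_and_test_without_partitioner,
        exporting with dynamic batch dimensions (torch.export.Dim) over the
        leading num_batch_dims dims.

        uses_bias and quant_type are accepted by the callers above but are not
        currently consumed by this helper.
        """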
        in_sizes = [3, 4, 4]
        input_sizes = [4, 37, 17]
        output_sizes = [4, 17, 37]

        for i, _ in enumerate(in_sizes):
            in_size = int(in_sizes[i])
            input_size = int(input_sizes[i])
            output_size = int(output_sizes[i])
            input_shape = [in_size] * num_batch_dims + [input_size]
            print(f"Testing input_shape {input_shape} with {output_size} out_channels")

            module = make_module(input_size, output_size).eval().to(dtype)
            inputs = (torch.randn(input_shape).to(dtype),)
            # Mark each leading batch dimension as dynamic for export.
            dynamic_shape = {}
            for dim in range(num_batch_dims):
                dynamic_shape[dim] = torch.export.Dim(f"batch{dim}", min=2, max=in_size)

            dynamic_shape = (dynamic_shape,)
            print(dynamic_shape)
            self.lower_and_test_without_partitioner(
                module,
                inputs,
                func_name=inspect.stack()[0].function[5:],
                dynamic_shapes=dynamic_shape,
                atol=atol,
                rtol=1e-03,
            )

    class ManualDQLinear(torch.nn.Module):
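        """Linear module with manually materialized quantize/dequantize ops.

        The weights are quantized once at construction time, either symmetrically
        per channel or per channel group (when group_size differs from
        input_channels or force_groupwise_quant is set), using a 4- or 8-bit
        signed range stored in an int8 tensor. Activations are dynamically
        quantized per token in forward(); both weights and activations are
        dequantized back to floating point before torch.nn.functional.linear, so
        the traced graph contains explicit quantized_decomposed.* ops.

        Illustrative usage (sizes chosen arbitrarily; assumes the
        quantized_decomposed ops are registered, as they are when these tests run):

            mod = TestLinear.ManualDQLinear(
                input_channels=64, output_channels=32, weight_n_bit=4, group_size=32
            )
            out = mod(torch.randn(1, 8, 64))
        """
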
        def __init__(
            self,
            input_channels: int = 4,
            output_channels: int = 4,
            dtype: torch.dtype = torch.float,
            weight_n_bit: int = 4,
            group_size: int = 0,
            force_groupwise_quant: bool = False,
            use_bias: bool = False,
        ):
            super().__init__()

            self.ic = input_channels
            self.oc = output_channels

            assert dtype in [torch.float, torch.half], "Unsupported op dtype"
            self.op_dtype = dtype

            self.group_size = self.ic if group_size == 0 else group_size
            self.num_groups = 1
            if self.group_size != self.ic:
                assert self.ic % self.group_size == 0
                assert self.group_size % 8 == 0  # TODO make this 16
                self.num_groups = self.ic // self.group_size

            assert weight_n_bit in [4, 8], "Unsupported weight_n_bit"
            self.w_n_bit = weight_n_bit
            self.w_quant_min, self.w_quant_max = self.get_min_max(self.w_n_bit)

            self.w = torch.nn.Parameter(
                torch.randn(self.oc, self.ic), requires_grad=False
            )
            self.w_q = torch.nn.Parameter(
                torch.zeros(self.oc, self.ic), requires_grad=False
            )
            # Quantize the weights as per folded setup
            if self.group_size != self.ic or force_groupwise_quant:
                self.w_scales = torch.nn.Parameter(
                    torch.zeros(self.oc, self.num_groups), requires_grad=False
                )
                self.w_zero_points = torch.nn.Parameter(
                    torch.zeros(self.oc, self.num_groups), requires_grad=False
                )
                self.quant_weight_per_channel_group()
            else:  # per_channel quantization
                self.w_scales = torch.nn.Parameter(
                    torch.zeros(self.oc), requires_grad=False
                )
                self.w_zero_points = torch.nn.Parameter(
                    torch.zeros(self.oc), requires_grad=False
                )
                self.quant_weight_per_channel()

            self.bias = (
                torch.nn.Parameter(
                    torch.randn(self.oc).to(self.op_dtype), requires_grad=False
                )
                if use_bias
                else None
            )

        def get_min_max(self, n_bit: int = 4):
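            """Signed symmetric integer range for an n_bit type, e.g. (-8, 7)
            for 4-bit and (-128, 127) for 8-bit."""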
            max_int = 2 ** (n_bit - 1) - 1
            min_int = -(2 ** (n_bit - 1))
            return min_int, max_int

        def get_channel_qparams_symmetric(
            self,
            w: torch.Tensor,
            n_bit: int = 4,
            precision: torch.dtype = torch.float32,
        ):
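            """Symmetric per-channel qparams for a 2-D weight: per output channel,
            scale = max(|w|) / ((max_int - min_int) / 2), clamped to fp32 eps, and
            zero point = 0. Returns 1-D (scales, zeros) of length w.shape[0]."""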
            assert w.dim() == 2

            to_quant = w.to(precision)
            assert torch.isnan(to_quant).sum() == 0

            max_val = to_quant.amax(dim=1, keepdim=True)
            min_val = to_quant.amin(dim=1, keepdim=True)
            min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
            max_val_pos = torch.max(max_val, torch.zeros_like(max_val))

            min_int, max_int = self.get_min_max(n_bit)

            max_val_abs = torch.max(-min_val_neg, max_val_pos)
            scales = max_val_abs / (float(max_int - min_int) / 2)
            scales = torch.max(
                scales, torch.full_like(scales, torch.finfo(torch.float32).eps)
            )
            zeros = torch.full_like(scales, 0)
            return scales.to(precision).reshape(w.shape[0]), zeros.to(
                precision
            ).reshape(w.shape[0])

        # Note: not using torchao.quantization.quant_primitives because it runs
        # into op registration issues
        def get_group_qparams_symmetric(
            self, w, n_bit=4, groupsize=128, precision=torch.float32
        ):
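            """Symmetric qparams per contiguous group of `groupsize` values along
            the last dim: scale = max(|group|) / ((max_int - min_int) / 2), clamped
            to fp32 eps, zero point = 0. Returns (scales, zeros), each shaped
            (out_channels, num_groups)."""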
            # needed for GPTQ with padding
            if groupsize > w.shape[-1]:
                groupsize = w.shape[-1]
            assert groupsize > 1
            assert w.shape[-1] % groupsize == 0
            assert w.dim() == 2

            to_quant = w.reshape(-1, groupsize)
            assert torch.isnan(to_quant).sum() == 0

            max_val = to_quant.amax(dim=1, keepdim=True)
            min_val = to_quant.amin(dim=1, keepdim=True)
            min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
            max_val_pos = torch.max(max_val, torch.zeros_like(max_val))

            max_val_abs = torch.max(-min_val_neg, max_val_pos)
            max_int = 2 ** (n_bit - 1) - 1
            min_int = -(2 ** (n_bit - 1))

            scales = max_val_abs / (float(max_int - min_int) / 2)
            scales = torch.max(
                scales, torch.full_like(scales, torch.finfo(torch.float32).eps)
            )
            # TODO: make sure abs(scales) is not too small?
            zeros = torch.full_like(scales, 0)
            return scales.to(precision).reshape(w.shape[0], -1), zeros.to(
                precision
            ).reshape(w.shape[0], -1)

        # Note: not using torchao.quantization.quant_primitives because it runs
        # into op registration issues
        def group_quantize_tensor_symmetric(
            self, w, n_bit=4, group_size=128, precision=torch.float32
        ):
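            """Quantizes w groupwise with symmetric qparams and returns
            (w_int8, scales, zeros); the quantized values are produced by
            quantized_decomposed.quantize_per_channel_group and stored in an
            int8 tensor regardless of n_bit."""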
            scales, zeros = self.get_group_qparams_symmetric(
                w, n_bit, group_size, precision
            )
            max_int = 2 ** (n_bit - 1) - 1
            min_int = -(2 ** (n_bit - 1))
            # TODO: currently we don't know how to express torch.int4, we'll
            # add torch.int4 to core later
            w_int8 = torch.ops.quantized_decomposed.quantize_per_channel_group(
                w, scales, zeros, min_int, max_int, torch.int8, group_size
            )

            return w_int8, scales, zeros

        def fwd_input_per_token(self, input: torch.Tensor) -> torch.Tensor:
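            """Dynamic per-token activation fake-quantization: chooses asymmetric
            int8 qparams per token, quantizes, then immediately dequantizes back
            to op_dtype, leaving the Q/DQ ops explicit in the traced graph."""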
            ip_quant_min = -128
            ip_quant_max = 127
            (
                ip_scales,
                ip_zero_points,
            ) = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(
                input, torch.int8
            )

            input = torch.ops.quantized_decomposed.quantize_per_token(
                input,
                ip_scales,
                ip_zero_points,
                ip_quant_min,
                ip_quant_max,
                torch.int8,
            )
            input = torch.ops.quantized_decomposed.dequantize_per_token(
                input,
                ip_scales,
                ip_zero_points,
                ip_quant_min,
                ip_quant_max,
                torch.int8,
                self.op_dtype,
            )
            return input

        def quant_weight_per_channel(self):
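            """Computes per-channel symmetric qparams for self.w and stores the
            int8-quantized weights in self.w_q."""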
            (
                self.w_scales.data,
                self.w_zero_points.data,
            ) = self.get_channel_qparams_symmetric(
                self.w, n_bit=self.w_n_bit, precision=self.op_dtype
            )
            self.w_q.data = torch.ops.quantized_decomposed.quantize_per_channel(
                self.w,
                self.w_scales,
                self.w_zero_points,
                axis=0,
                quant_min=self.w_quant_min,
                quant_max=self.w_quant_max,
                dtype=torch.int8,
            )

        def quant_weight_per_channel_group(self):
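            """Quantizes self.w per channel group, checks the quantized values
            stay within the w_n_bit range, and stores the 2-D scales and zero
            points."""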
            self.w_q.data, scales, zp = self.group_quantize_tensor_symmetric(
                self.w,
                n_bit=self.w_n_bit,
                group_size=self.group_size,
            )
            expected_min, expected_max = self.get_min_max(self.w_n_bit)
            assert (
                torch.min(self.w_q.data) >= expected_min
            ), "Found an element smaller than min in the quantized weight tensor"
            assert (
                torch.max(self.w_q.data) <= expected_max
            ), "Found an element larger than max in the quantized weight tensor"
            assert (
                scales.ndim == 2 and zp.ndim == 2
            ), f"Expecting 2d scales and zp tensors, but got {scales.shape}, {zp.shape}"
            self.w_scales.data, self.w_zero_points.data = scales, zp

        def fwd_weight_per_channel(self) -> torch.Tensor:
            # This is HACKY because the dequant will produce fp32
            return torch.ops.quantized_decomposed.dequantize_per_channel(
                self.w_q,
                self.w_scales,
                self.w_zero_points,
                axis=0,
                quant_min=self.w_quant_min,
                quant_max=self.w_quant_max,
                dtype=torch.int8,  # Regardless of w_n_bit, convert to 4b later
            )

        def fwd_weight_per_channel_group(self) -> torch.Tensor:
            return torch.ops.quantized_decomposed.dequantize_per_channel_group(
                self.w_q,
                self.w_scales,
                self.w_zero_points,
                self.w_quant_min,
                self.w_quant_max,
                dtype=torch.int8,  # Regardless of w_n_bit, convert to 4b later
                group_size=self.group_size,
                output_dtype=self.op_dtype,
            )

        def forward(self, input: torch.Tensor) -> torch.Tensor:
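            """Fake-quantizes the input per token, dequantizes the pre-quantized
            weights (groupwise when the scales are 2-D, per channel otherwise),
            and applies a regular torch.nn.functional.linear."""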
            # Input
            input = self.fwd_input_per_token(input)

            # Weights
            w = (
                self.fwd_weight_per_channel_group()
                if self.w_scales.ndim == 2
                else self.fwd_weight_per_channel()
            )
            assert isinstance(w, torch.Tensor)
            return torch.nn.functional.linear(input, w, self.bias)

    def _test_manual_dq_linear(
        self,
        mod: torch.nn.Module,
        inputs: Tuple[torch.Tensor],
        weight_groupwise: bool = False,  # not currently used by this helper
        use_bias: bool = False,  # not currently used by this helper
        atol=None,
        rtol=None,
    ):
        # Only forward tolerances that the caller explicitly provided.
        tol_kwargs = {}
        if atol is not None:
            tol_kwargs["atol"] = atol
        if rtol is not None:
            tol_kwargs["rtol"] = rtol
        self.lower_and_test_without_partitioner(
            mod, inputs, func_name=inspect.stack()[0].function[5:], **tol_kwargs
        )

    def _run_manual_dqlinear_tests(self, weight_n_bit: int, op_dtype: torch.dtype):
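        """Runs ManualDQLinear with per-channel weight quantization at the given
        bit width and dtype over a small sweep of (batch, in, out) sizes."""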
        in_sizes = [1, 4, 4]
        input_sizes = [4, 37, 17]
        output_sizes = [4, 17, 37]

        for use_bias in [True, False]:
            for i, _ in enumerate(in_sizes):
                in_size = int(in_sizes[i])
                input_size = int(input_sizes[i])
                output_size = int(output_sizes[i])
                mod = self.ManualDQLinear(
                    input_channels=input_size,
                    output_channels=output_size,
                    weight_n_bit=weight_n_bit,
                    dtype=op_dtype,
                    use_bias=use_bias,
                )

                inputs = (torch.randn(1, in_size, input_size).to(op_dtype),)
                self._test_manual_dq_linear(mod, inputs, use_bias=use_bias)
524