# Owner(s): ["oncall: cpu inductor"]
import contextlib
import copy
import itertools
import unittest

import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq

from torch._dynamo import config as dynamo_config
from torch._dynamo.utils import counters
from torch._export import capture_pre_autograd_graph
from torch._inductor import config, metrics
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_code
from torch.ao.quantization.quantize_pt2e import (
    convert_pt2e,
    prepare_pt2e,
    prepare_qat_pt2e,
)
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
from torch.nn import functional as F
from torch.testing._internal.common_quantization import (
    skipIfNoDynamoSupport,
    skipIfNoONEDNN,
    skipIfNoONEDNNBF16,
)
from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm, TEST_MKL
from torch.testing._internal.inductor_utils import _check_has_dynamic_shape, HAS_CPU


# The dict value is match_nodes (computation_op + unary_op).

unary_list = {
    torch.nn.ReLU(): 2,
    torch.nn.Sigmoid(): 2,
    torch.nn.Tanh(): 2,
    torch.nn.Hardswish(): 6,
    torch.nn.LeakyReLU(0.1, inplace=False): 4,
    torch.nn.Hardtanh(min_val=-0.5, max_val=4, inplace=False): 3,
    torch.nn.Hardtanh(min_val=-0.5, max_val=float("inf"), inplace=False): 3,
    torch.nn.GELU(approximate="none"): 6,
    torch.nn.GELU(approximate="tanh"): 10,
    torch.nn.ReLU6(): 3,
    torch.nn.SiLU(): 3,
    torch.nn.Hardsigmoid(): 5,
}

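# These ops are kept as single aten ops under autocast (no decomposition), so
# fusing them does not add extra dtype-conversion nodes to a match
# (see _check_unary_is_decomposed below).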
non_decomposed_unary_list = [
    torch.nn.ReLU,
    torch.nn.Sigmoid,
    torch.nn.Tanh,
]

# The dict value is (match_count, match_nodes, inplace)
binary_list = {
    lambda x, y: torch.add(x, y): (1, 2, False),  # call_function
    lambda x, y: torch.add(y, x): (1, 2, False),  # call_function
    lambda x, y: x.add(y): (1, 2, False),  # call_method
    lambda x, y: x.add_(y): (1, 2, True),  # call_method
    lambda x, y: torch.sub(x, y): (1, 2, False),  # call_function
    lambda x, y: x.sub(y): (1, 2, False),  # call_method
    lambda x, y: x.sub_(y): (1, 2, True),  # call_method
}

quantization_add_fn_list = [
    lambda x, y: torch.add(x, y),
    lambda x, y: x.add(y),
]

quantization_inplace_add_fn_list = [
    lambda x, y: x.add_(y),
]


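# Build an X86InductorQuantizer with the default backend recipe; is_qat and
# is_dynamic select the QAT and dynamic-quantization variants of the config.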
def get_default_quantizer(is_qat, is_dynamic):
    quantizer = X86InductorQuantizer()
    quantizer.set_global(
        xiq.get_default_x86_inductor_quantization_config(
            is_qat=is_qat, is_dynamic=is_dynamic
        )
    )
    return quantizer


def cal_conv_generated_kernel_number(mod, input, dtype):
    # This function decides how many kernels are generated
    # while testing conv2d/3d/deconv2d.
    # The assumptions are:
    #   (1) There will be a to_dtype kernel for the input for low-precision dtypes.
    #   (2) Inductor always uses the channels_last format, so there will
    #       be a to_channels_last kernel for the input.
    #   (3) The to_dtype and to_channels_last kernels for the input can be fused.
    #   (4) Inductor always gets the channels_last format from mkldnn_conv_pointwise(binary)
    #       and forces the output to have the same stride as eager mode,
    #       so there will be a to_contiguous kernel for the output if the eager output is contiguous.
    mod = copy.deepcopy(mod)
    input = input.clone()
    if dtype == torch.float32:
        maybe_autocast = contextlib.nullcontext()
    else:
        maybe_autocast = torch.cpu.amp.autocast(dtype=dtype)
    with torch.no_grad(), maybe_autocast:
        output = mod(input)
    input_kernel, output_kernel = 0, 0
    if (
        input.is_contiguous(memory_format=torch.contiguous_format)
        or dtype != torch.float32
    ):
        input_kernel = 1
    if output.is_contiguous(memory_format=torch.contiguous_format):
        output_kernel = 1
    return input_kernel + output_kernel


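# Freezing constant-folds module parameters into the compiled graph, which is
# what allows the onednn/mkldnn weight-prepack and fusion passes tested below to fire.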
@config.patch({"freezing": True})
class TestPatternMatcherBase(TestCase):
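    # ReLU/Sigmoid/Tanh stay as single aten ops under autocast; every other op
    # in unary_list is decomposed, which adds dtype-conversion nodes to a match.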
    def _check_unary_is_decomposed(self, unary_fn):
        return not any(
            isinstance(unary_fn, fn)
            for fn in [torch.nn.ReLU, torch.nn.Sigmoid, torch.nn.Tanh]
        )

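    # Clone tensor inputs so the compiled run does not observe in-place
    # mutations (e.g. add_/sub_) performed by the eager reference run.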
    def _clone_inputs(self, inputs):
        def clone(x):
            if not isinstance(x, torch.Tensor):
                return x
            return x.clone()

        return tuple(clone(x) for x in inputs)

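    # Export the module with capture_pre_autograd_graph, run the PT2E
    # prepare/convert flow (QAT or PTQ), calibrate with a single forward pass,
    # and return the converted QDQ reference model.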
    def _generate_qdq_quantized_model(
        self, mod, inputs, is_qat=False, is_dynamic=False, quantizer=None
    ):
        maybe_no_grad = contextlib.nullcontext() if is_qat else torch.no_grad()
        with maybe_no_grad:
            export_model = capture_pre_autograd_graph(
                mod,
                inputs,
            )
            quantizer = (
                quantizer if quantizer else get_default_quantizer(is_qat, is_dynamic)
            )
            prepare_model = (
                prepare_qat_pt2e(export_model, quantizer)
                if is_qat
                else prepare_pt2e(export_model, quantizer)
            )
            prepare_model(*inputs)
            convert_model = convert_pt2e(prepare_model)
            torch.ao.quantization.move_exported_model_to_eval(convert_model)
            return convert_model

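    # Compile `mod` under optional autocast (and optional PT2E quantization),
    # compare against the eager reference on the non-quantized path, and
    # verify the pattern-matcher counters.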
    def _test_common(
        self,
        mod,
        inputs,
        matcher_count=None,
        matcher_nodes=None,
        atol=1e-5,
        rtol=1.3e-6,
        check_autocast=torch.float32,
        check_quantization=False,
        is_qat=False,
        matcher_check_fn=None,
        dtype=None,
        is_dynamic=False,
        quantizer=None,
    ):
        counters.clear()
        torch._dynamo.reset()
        assert matcher_check_fn is not None or (
            matcher_count is not None and matcher_nodes is not None
        )
        if (
            check_autocast == torch.bfloat16
            and torch.ops.mkldnn._is_mkldnn_bf16_supported()
        ):
            maybe_autocast = torch.cpu.amp.autocast(dtype=torch.bfloat16)
            atol, rtol = 1e-2, 1e-2
        elif (
            check_autocast == torch.float16
            and torch.ops.mkldnn._is_mkldnn_fp16_supported()
        ):
            maybe_autocast = torch.cpu.amp.autocast(dtype=torch.float16)
            atol, rtol = 1e-2, 1e-2
        else:
            assert check_autocast == torch.float32
            maybe_autocast = contextlib.nullcontext()

        if check_quantization:
            convert_model = self._generate_qdq_quantized_model(
                mod, inputs, is_qat, is_dynamic, quantizer
            )
            with torch.no_grad(), maybe_autocast:
                _ = torch.compile(convert_model)(*inputs)
                if matcher_count is not None:
                    self.assertEqual(
                        counters["inductor"]["pattern_matcher_count"], matcher_count
                    )
                if matcher_nodes is not None:
                    self.assertEqual(
                        counters["inductor"]["pattern_matcher_nodes"],
                        matcher_nodes,
                    )
                if matcher_check_fn is not None:
                    matcher_check_fn()
        else:
            with torch.no_grad(), maybe_autocast:
                clone_inputs = self._clone_inputs(inputs)
                expected = mod(*inputs)
                actual = torch.compile(mod)(*clone_inputs)
                torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
                if matcher_count is not None:
                    self.assertEqual(
                        counters["inductor"]["pattern_matcher_count"], matcher_count
                    )
                if matcher_nodes is not None:
                    self.assertEqual(
                        counters["inductor"]["pattern_matcher_nodes"],
                        matcher_nodes,
                    )
                if matcher_check_fn is not None:
                    matcher_check_fn()

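    # Compile `mod`, capture the generated source code, and assert on the ops
    # it must (or must not) contain; optionally check dynamic-shape codegen.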
    def _test_code_common(
        self,
        mod,
        inputs,
        include_ops,
        exclude_ops,
        atol=1e-5,
        rtol=1.3e-6,
        check_quantization=False,
        check_dynamic=None,
        num_include_ops=None,
    ):
        with torch.no_grad():
            clone_inputs = self._clone_inputs(inputs)
            if check_quantization:
                mod = self._generate_qdq_quantized_model(mod, inputs)
            expected = mod(*inputs)
            actual, (source_code,) = run_and_get_code(
                torch.compile(mod, fullgraph=True, dynamic=check_dynamic),
                *clone_inputs,
            )
            for op in include_ops:
                self.assertIn(op, source_code)
            if num_include_ops is not None:
                assert len(include_ops) == len(num_include_ops)
                for i in range(len(include_ops)):
                    self.assertEqual(
                        source_code.count(include_ops[i]), num_include_ops[i]
                    )
            for op in exclude_ops:
                self.assertNotIn(op, source_code)
            if check_dynamic is not None:
                _check_has_dynamic_shape(self, source_code)
            if not check_quantization:
                # Skip the accuracy check for quantization due to the reduced
                # range setting on the pre-CI system.
                torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)


class TestPatternMatcher(TestPatternMatcherBase):
    def _test_conv_unary_cpu_base(self, dim=4):
        assert dim == 4 or dim == 5

        class M(torch.nn.Module):
            def __init__(
                self,
                unary_fn,
                **kwargs,
            ):
                super().__init__()
                if dim == 4:
                    self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1)
                else:
                    self.conv = torch.nn.Conv3d(3, 16, kernel_size=3, stride=1)
                self.unary_fn = unary_fn

            def forward(self, x):
                x = self.conv(x)
                return self.unary_fn(x)

        dtypes = [
            torch.float,
        ]
        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
            dtypes.append(torch.bfloat16)
        if torch.ops.mkldnn._is_mkldnn_fp16_supported():
            dtypes.append(torch.float16)
        cl_format = torch.channels_last if dim == 4 else torch.channels_last_3d
        options = itertools.product(
            unary_list.keys(),
            [torch.contiguous_format, cl_format],
            dtypes,
        )

        for (
            unary_fn,
            memory_format,
            dtype,
        ) in options:
            metrics.reset()
            if dim == 4:
                x_shape = (1, 3, 56, 56)
            else:
                x_shape = (1, 3, 20, 56, 56)
            mod = M(unary_fn).to(memory_format=memory_format).eval()

            v = (
                torch.randn(x_shape, dtype=torch.float32)
                .add(1)
                .to(memory_format=memory_format)
            )
            # Add 1 for weight packing pass.
            match_nodes = unary_list[unary_fn] + 1
            if dtype in (
                torch.float16,
                torch.bfloat16,
            ) and self._check_unary_is_decomposed(unary_fn):
                # Has extra dtype conversion nodes for autocast.
                match_nodes += 2
            self._test_common(mod, (v,), 2, match_nodes, check_autocast=dtype)
            generated_kernel_count = cal_conv_generated_kernel_number(mod, v, dtype)
            self.assertEqual(metrics.generated_kernel_count, generated_kernel_count)

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_conv2d_unary_cpu(self):
        self._test_conv_unary_cpu_base(dim=4)

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_conv3d_unary_cpu(self):
        self._test_conv_unary_cpu_base(dim=5)

    def test_linear_unary(self):
        class M(torch.nn.Module):
            def __init__(
                self,
                unary_fn,
                in_features,
                out_features,
                bias,
                **kwargs,
            ):
                super().__init__()
                self.linear = torch.nn.Linear(
                    in_features,
                    out_features,
                    bias,
                    **kwargs,
                )
                self.unary_fn = unary_fn

            def forward(self, x):
                x = self.linear(x)
                return self.unary_fn(x)

        dtypes = []
        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
            dtypes.append(torch.bfloat16)
        if torch.ops.mkldnn._is_mkldnn_fp16_supported():
            dtypes.append(torch.float16)
        options = itertools.product(unary_list, [True, False], dtypes)
        for unary_fn, bias, dtype in options:
            metrics.reset()
            mod = M(unary_fn, 10, 30, bias=bias).eval()
            # Unary fusion for Linear only happens for low-precision (bf16/fp16) dtypes.
            v = torch.randn(2, 10)
            # packing pass + unary fusion.
            matcher_count = 2
            # Add 1 for weight packing pass.
            matcher_nodes = unary_list[unary_fn] + 1
            if self._check_unary_is_decomposed(unary_fn):
                # Has extra dtype conversion nodes for autocast.
                matcher_nodes += 2
            self._test_common(
                mod, (v,), matcher_count, matcher_nodes, check_autocast=dtype
            )
            # Only 1 kernel is generated, for the "to" dtype conversion.
            self.assertEqual(metrics.generated_kernel_count, 1)

    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    def test_linear_fp32(self):
        class M(torch.nn.Module):
            def __init__(self, bias):
                super().__init__()
                self.linear = torch.nn.Linear(10, 30, bias)

            def forward(self, x):
                return self.linear(x)

        for bias in [True, False]:
            mod = M(bias=bias).eval()
            v = torch.randn(2, 10)
            # packing pass.
            matcher_count = 1
            matcher_nodes = 1
            self._test_common(mod, (v,), matcher_count, matcher_nodes)

    def test_linear_add_bias(self):
        class M(torch.nn.Module):
            def __init__(self, dtype, unary_fn):
                super().__init__()
                self.linear1 = torch.nn.Linear(10, 64, bias=False)
                self.bias1 = torch.randn(64).to(dtype=dtype)
                self.linear2 = torch.nn.Linear(10, 64, bias=False)
                self.bias2 = torch.randn(64).to(dtype=dtype)
                self.unary_fn = unary_fn

            def forward(self, x):
                a = self.linear1(x) + self.bias1
                b = self.linear2(x) + self.bias2
                return self.unary_fn(a), self.unary_fn(b)

        dtypes = []
        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
            dtypes.append(torch.bfloat16)
        if torch.ops.mkldnn._is_mkldnn_fp16_supported():
            dtypes.append(torch.float16)
        options = itertools.product(unary_list, dtypes)
        for unary_fn, dtype in options:
            metrics.reset()
            mod = M(dtype, unary_fn).eval()
            v = torch.randn(2, 10)
            matcher_count = 3
            # Add 1 for weight packing pass, add 2 for bias folding pass per linear.
            matcher_nodes = unary_list[unary_fn] + 3
            if self._check_unary_is_decomposed(unary_fn):
                # Has extra dtype conversion nodes for autocast.
                matcher_nodes += 2
            # We have 2 linears, so we double the matcher_count/nodes.
            self._test_common(
                mod, (v,), matcher_count * 2, matcher_nodes * 2, check_autocast=dtype
            )
            self.assertEqual(metrics.generated_kernel_count, 1)

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_conv_transpose2d_unary(self):
        class M(torch.nn.Module):
            def __init__(
                self,
                unary_fn,
                **kwargs,
            ):
                super().__init__()
                self.conv_transpose2d = torch.nn.ConvTranspose2d(
                    3, 16, 3, stride=2, padding=1
                )
                self.unary_fn = unary_fn

            def forward(self, x):
                x = self.conv_transpose2d(x)
                return self.unary_fn(x)

        dtypes = [
            torch.float,
        ]
        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
            dtypes.append(torch.bfloat16)
        if torch.ops.mkldnn._is_mkldnn_fp16_supported():
            dtypes.append(torch.float16)

        options = itertools.product(
            unary_list,
            [torch.contiguous_format, torch.channels_last],
            dtypes,
        )

        for unary_fn, memory_format, dtype in options:
            metrics.reset()
            x_shape = (1, 3, 28, 28)
            mod = M(unary_fn).eval()

            v = torch.randn(x_shape, dtype=torch.float32).to(
                memory_format=memory_format
            )
            # Add 1 for weight packing pass.
            match_nodes = unary_list[unary_fn] + 1
            if dtype in (
                torch.float16,
                torch.bfloat16,
            ) and self._check_unary_is_decomposed(unary_fn):
                # Has extra dtype conversion nodes for autocast.
                match_nodes += 2
            self._test_common(mod, (v,), 2, match_nodes, check_autocast=dtype)
            generated_kernel_count = cal_conv_generated_kernel_number(mod, v, dtype)
            self.assertEqual(metrics.generated_kernel_count, generated_kernel_count)

    def _test_conv_binary_base(self, dim=4):
        assert dim == 4 or dim == 5

        class M(torch.nn.Module):
            def __init__(
                self,
                binary_fn,
                has_relu,
                **kwargs,
            ):
                super().__init__()
                if dim == 4:
                    self.conv1 = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1)
                    self.conv2 = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1)
                else:
                    self.conv1 = torch.nn.Conv3d(3, 16, kernel_size=3, stride=1)
                    self.conv2 = torch.nn.Conv3d(3, 16, kernel_size=3, stride=1)
                self.binary_fn = binary_fn
                self.has_relu = has_relu

            def forward(self, x):
                x1 = self.conv1(x)
                x2 = self.conv2(x)
                if self.has_relu:
                    return self.binary_fn(x1, x2).relu()
                else:
                    return self.binary_fn(x1, x2)

        dtypes = [
            torch.float,
        ]
        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
            dtypes.append(torch.bfloat16)
        if torch.ops.mkldnn._is_mkldnn_fp16_supported():
            dtypes.append(torch.float16)
        cl_format = torch.channels_last if dim == 4 else torch.channels_last_3d
        test_memory_format = [torch.contiguous_format, cl_format]
        options = itertools.product(
            binary_list,
            [True, False],
            test_memory_format,
            dtypes,
        )

        for (
            binary_fn,
            has_relu,
            memory_format,
            dtype,
        ) in options:
            metrics.reset()
            if dim == 4:
                x_shape = (1, 3, 56, 56)
            else:
                x_shape = (1, 3, 20, 56, 56)
            mod = M(binary_fn, has_relu).eval()
            v = (
                torch.randn(x_shape, dtype=torch.float32, requires_grad=True)
                .add(1)
                .to(memory_format=memory_format)
            )
            match_count = binary_list[binary_fn][0] + 2
            match_nodes = binary_list[binary_fn][1]
            if has_relu:
                match_nodes += 1
            self._test_common(
                mod, (v,), match_count, match_nodes + 2, check_autocast=dtype
            )
            generated_kernel_count = cal_conv_generated_kernel_number(mod, v, dtype)
            self.assertEqual(metrics.generated_kernel_count, generated_kernel_count)

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_conv2d_binary(self):
        self._test_conv_binary_base(dim=4)

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_conv3d_binary(self):
        self._test_conv_binary_base(dim=5)

    def test_linear_binary(self):
        class M(torch.nn.Module):
            def __init__(self, binary_fn, in_channels, out_channels, bias, **kwargs):
                super().__init__()
                self.linear = torch.nn.Linear(
                    in_channels, out_channels, bias=bias, **kwargs
                )
                self.binary_fn = binary_fn

            def forward(self, x, y):
                x = self.linear(x)
                x = self.binary_fn(x, y.clone())
                return x

        dtypes = []
        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
            dtypes.append(torch.bfloat16)
        if torch.ops.mkldnn._is_mkldnn_fp16_supported():
            dtypes.append(torch.float16)
        options = itertools.product(
            binary_list, [[2, 3, 10], [2, 10]], [True, False], dtypes
        )
        out_feature = 30
        for binary_fn, input_shape, bias, dtype in options:
            metrics.reset()
            # addmm(mm) + (linear+add)
            match_count = 2
            match_nodes = 3
            if len(input_shape) == 3:
                is_inplace = binary_list[binary_fn][2]
                # view + linear + view(joint_graph+freeze pass)
                match_count = match_count + 5 if is_inplace else match_count + 3
                match_nodes = match_nodes + 7 if is_inplace else match_nodes + 5
            mod = M(binary_fn, input_shape[-1], out_feature, bias).eval()
            v = torch.randn(input_shape)
            other = torch.randn(input_shape[:-1] + [out_feature]).to(dtype)
            self._test_common(
                mod,
                (
                    v,
                    other,
                ),
                match_count,
                match_nodes,
                check_autocast=dtype,
            )
            self.assertEqual(metrics.generated_kernel_count, 1)

    def test_multi_linear_share_same_input(self):
        # llama pattern.
        class M(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                self.w1 = torch.nn.Linear(16, 16, bias=False)
                self.w2 = torch.nn.Linear(16, 16, bias=False)

            def forward(self, x):
                return F.silu(self.w1(x)) * F.relu(self.w2(x))

        dtypes = []
        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
            dtypes.append(torch.bfloat16)
        if torch.ops.mkldnn._is_mkldnn_fp16_supported():
            dtypes.append(torch.float16)
        for dtype in dtypes:
            mod = M().to(dtype).eval()
            v = torch.randn(2, 4, 16).to(dtype)
            # 1. view(match_count=4, match_nodes=4).
            # 2. mm to packed linear(match_count=2, match_nodes=2).
            # 3. view+linear+view to linear(match_count=2, match_nodes=6).
            # 4. linear+silu fusion(match_count=1, match_nodes=5)
            # 5. linear+relu fusion(match_count=1, match_nodes=2)

            match_count = 10
            match_nodes = 19
            self._test_common(mod, (v,), match_count, match_nodes, rtol=1e-2, atol=1e-2)

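    # Helper: quantize two stacked Conv2d modules and verify that the
    # dequant-conv2d weight-prepack pattern is matched for both of them.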
    def _qconv2d_cpu_test_helper(self, int8_mixed_bf16=False):
        class M(torch.nn.Module):
            def __init__(
                self,
                **kwargs,
            ):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 128, kernel_size=3, stride=1)
                self.conv2 = torch.nn.Conv2d(128, 128, kernel_size=3, stride=1)

            def forward(self, x):
                return self.conv2(self.conv(x))

        mod = M().eval()
        v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1)

        def matcher_check_fn():
            # 1. Dequant-Conv2D pattern matched in QConv2D weight prepack * 2
            #    int8_mixed_fp32: [dequant_node, dequantize_per_channel, clone, convolution]
            #    int8_mixed_bf16: [dequant_node, optional(convert_element_type_4),
            #     dequantize_per_channel, optional(convert_element_type_3), clone, convolution]
            self.assertEqual(
                counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2
            )
            self.assertEqual(
                counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"],
                12 if int8_mixed_bf16 else 8,
            )

        self._test_common(
            mod,
            (v,),
            check_quantization=True,
            check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
            matcher_check_fn=matcher_check_fn,
        )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qconv2d_cpu(self):
        r"""
        This testcase will quantize a single Conv2d module.
        """
        self._qconv2d_cpu_test_helper()

    @skipIfNoDynamoSupport
    @skipIfNoONEDNNBF16
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qconv2d_int8_mixed_bf16(self):
        r"""
        This testcase will quantize a single Conv2d module with int8_mixed_bf16 quantization.
        """
        self._qconv2d_cpu_test_helper(int8_mixed_bf16=True)

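    # Helper: quantize two Conv2d->unary_op pairs and check both the
    # weight-prepack matcher and the qconv2d+unary post-grad fusion matcher.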
    def _qconv2d_unary_cpu_test_helper(
        self,
        int8_mixed_bf16=False,
        unary_op=torch.nn.ReLU(),
        qconv2d_unary_matcher_nodes=None,
    ):
        class M(torch.nn.Module):
            def __init__(
                self,
                **kwargs,
            ):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 128, kernel_size=3, stride=1)
                self.unary_fn = copy.deepcopy(unary_op)
                self.conv2 = torch.nn.Conv2d(128, 128, kernel_size=3, stride=1)
                self.unary_fn2 = copy.deepcopy(unary_op)

            def forward(self, x):
                tmp = self.unary_fn(self.conv(x))
                return self.unary_fn2(self.conv2(tmp))

        mod = M().eval()
        v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1)

        def matcher_check_fn():
            # 1. Dequant-Conv2D pattern matched in quantization weight prepack * 2
            self.assertEqual(
                counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2
            )
            # 2. QConv2D Unary fusion in post-grad fusion pass * 2
            self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_count"], 2)
            if qconv2d_unary_matcher_nodes:
                self.assertEqual(
                    counters["inductor"]["qconv2d_unary_matcher_nodes"],
                    qconv2d_unary_matcher_nodes,
                )

        self._test_common(
            mod,
            (v,),
            check_quantization=True,
            check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
            matcher_check_fn=matcher_check_fn,
        )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qconv2d_relu_cpu(self):
        r"""
        This testcase will quantize the Conv2d->ReLU pattern.
        """
        self._qconv2d_unary_cpu_test_helper()

    @skipIfNoDynamoSupport
    @skipIfNoONEDNNBF16
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qconv2d_relu_int8_mixed_bf16(self):
        r"""
        This testcase will quantize the Conv2d->ReLU pattern with int8_mixed_bf16 quantization.
        """
        self._qconv2d_unary_cpu_test_helper(int8_mixed_bf16=True)

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qconv2d_relu6_cpu(self):
        r"""
        This testcase will quantize the Conv2d->ReLU6 pattern.
        """
        self._qconv2d_unary_cpu_test_helper(unary_op=torch.nn.ReLU6())

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qconv2d_hardtanh_cpu(self):
        r"""
        This testcase will quantize the Conv2d->Hardtanh pattern.
        """
        self._qconv2d_unary_cpu_test_helper(unary_op=torch.nn.Hardtanh())

    @skipIfNoDynamoSupport
    @skipIfNoONEDNNBF16
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qconv2d_hardtanh_int8_mixed_bf16_cpu(self):
        r"""
        This testcase will quantize the Conv2d->Hardtanh pattern.
        Match.nodes:
            [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type, quantize_per_tensor]
            [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type]
        """
        self._qconv2d_unary_cpu_test_helper(
            unary_op=torch.nn.Hardtanh(),
            int8_mixed_bf16=True,
            qconv2d_unary_matcher_nodes=11,
        )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qconv2d_hardswish_cpu(self):
        r"""
        This testcase will quantize the Conv2d->Hardswish pattern.
        """
        self._qconv2d_unary_cpu_test_helper(unary_op=torch.nn.Hardswish())

    @skipIfNoDynamoSupport
    @skipIfNoONEDNNBF16
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qconv2d_hardswish_int8_mixed_bf16_cpu(self):
        r"""
        This testcase will quantize the Conv2d->Hardswish pattern.
        Match.nodes:
            [qconv2d_pointwise_default, convert_element_type, add, clamp_min,
             clamp_max, mul, div, convert_element_type, quantize_per_tensor]
            [qconv2d_pointwise_default, convert_element_type, add, clamp_min, clamp_max, mul, div, convert_element_type]
        """
        self._qconv2d_unary_cpu_test_helper(
            unary_op=torch.nn.Hardswish(),
            int8_mixed_bf16=True,
            qconv2d_unary_matcher_nodes=17,
        )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qconv2d_silu_cpu(self):
        r"""
        This testcase will quantize the Conv2d->SiLU pattern.
        """
        self._qconv2d_unary_cpu_test_helper(unary_op=torch.nn.SiLU())

    @skipIfNoDynamoSupport
    @skipIfNoONEDNNBF16
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qconv2d_silu_int8_mixed_bf16_cpu(self):
        r"""
        This testcase will quantize the Conv2d->SiLU pattern.
        Match.nodes:
            [qconv2d_pointwise_default, convert_element_type, sigmoid, mul,
             convert_element_type, quantize_per_tensor]
            [qconv2d_pointwise_default, convert_element_type, sigmoid, mul, convert_element_type]
        """
        self._qconv2d_unary_cpu_test_helper(
            unary_op=torch.nn.SiLU(),
            int8_mixed_bf16=True,
            qconv2d_unary_matcher_nodes=11,
        )

    def _qconv2d_add_cpu_test_helper(self, use_relu=False, int8_mixed_bf16=False):
        r"""
        This testcase will quantize a Conv2d->Add pattern as:
                 X
               /   \
        Conv1(X)   Conv2(X)
               \   /
                Add
                 |
           Optional(relu)
                 |
                 Y
        """

        class M(torch.nn.Module):
            def __init__(
                self,
                add_fn,
                use_relu,
                **kwargs,
            ):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
                self.conv2 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
                self.add_fn = add_fn
                self.relu = torch.nn.ReLU()
                self.conv3 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1)
                self.conv4 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1)
                self.add_fn2 = add_fn
                self.relu2 = torch.nn.ReLU()
                self.use_relu = use_relu

            def forward(self, x):
                x1 = self.conv1(x)
                x2 = self.conv2(x)
                tmp = self.add_fn(x1, x2)
                if self.use_relu:
                    tmp = self.relu(tmp)
                tmp1 = self.conv3(tmp)
                tmp2 = self.conv4(tmp)
                res = self.add_fn2(tmp1, tmp2)
                if self.use_relu:
                    res = self.relu2(res)
                return res

        for add_fn in quantization_add_fn_list + quantization_inplace_add_fn_list:
            mod = M(add_fn, use_relu).eval()
            v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(
                1
            )

            def matcher_check_fn():
                # 1. Dequant-Conv2D pattern matched in quantization weight prepack * 4
                self.assertEqual(
                    counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 4
                )
                # 2. Qconv2d Binary Unary fusion in post-grad fusion pass * 2
                self.assertEqual(
                    counters["inductor"]["qconv2d_binary_matcher_count"], 2
                )

            self._test_common(
                mod,
                (v,),
                check_quantization=True,
                check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
                matcher_check_fn=matcher_check_fn,
            )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qconv2d_add_cpu(self):
        self._qconv2d_add_cpu_test_helper()

    @skipIfNoDynamoSupport
    @skipIfNoONEDNNBF16
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qconv2d_add_int8_mixed_bf16(self):
        self._qconv2d_add_cpu_test_helper(int8_mixed_bf16=True)

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qconv2d_add_relu_cpu(self):
        self._qconv2d_add_cpu_test_helper(use_relu=True)

    @skipIfNoDynamoSupport
    @skipIfNoONEDNNBF16
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qconv2d_add_relu_int8_mixed_bf16(self):
        self._qconv2d_add_cpu_test_helper(use_relu=True, int8_mixed_bf16=True)

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qconv2d_add_broadcast_shapes_cpu(self):
        r"""
        This testcase will quantize the Conv2d->Add pattern with broadcast-shape inputs.
        Conv2d->Add fusion is expected to fail for the broadcast-shape input case.
        """

        class M(torch.nn.Module):
            def __init__(self, use_bias):
                super().__init__()
                self.conv = torch.nn.Conv2d(32, 32, kernel_size=3, stride=1, bias=use_bias)

            def forward(self, x1, x2):
                return torch.add(self.conv(x1), x2)

        bias_list = [True, False]
        for bias in bias_list:
            mod = M(bias).eval()
            x1 = torch.randn((2, 32, 9, 9))
            x2 = torch.randn((2, 32, 1, 1))

            def matcher_check_fn():
                # 1. Dequant-Conv2D pattern matched in quantization weight prepack * 1
                self.assertEqual(
                    counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 1
                )
                # 2. Qconv2d Binary Unary fusion in post-grad fusion pass * 0
                self.assertEqual(
                    counters["inductor"]["qconv2d_binary_matcher_count"], 0
                )

            self._test_common(
                mod,
                (x1, x2),
                check_quantization=True,
                matcher_check_fn=matcher_check_fn,
            )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qconv2d_add_2(self):
        r"""
        This testcase prevents this pattern from being matched as a conv_binary fusion by mistake:
                Conv(X)  3
                    \   /
                     Add
        We see this pattern in Mobilenet v3 large, where the add is decomposed from torch.nn.Hardswish or torch.nn.Hardsigmoid.
        """

        class M(torch.nn.Module):
            def __init__(
                self,
                post_op,
            ):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
                self.post_op = post_op

            def forward(self, x):
                return self.post_op(self.conv(x))

        for post_op in [
            torch.nn.Hardswish(inplace=True),
            torch.nn.Hardsigmoid(inplace=True),
        ]:
            mod = M(post_op).eval()
            v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(
                1
            )

            def matcher_check_fn():
                # Shouldn't hit conv binary fusion
                self.assertEqual(
                    counters["inductor"]["qconv2d_binary_matcher_count"], 0
                )

            self._test_common(
                mod,
                (v,),
                check_quantization=True,
                matcher_check_fn=matcher_check_fn,
            )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qconv2d_add_3(self):
        r"""
        This testcase will test the model below:
             x
           /   \
        conv1  maxpool
          \    /   \
           add    conv2
            \     /
              cat
        Based on the default recipe of X86InductorQuantizer, we will see this pattern after convert:
        qconv1    maxpool
         \           |
          \         q1
           \       /   \
            \     dq1  qconv2
             \   /
              add
               |
               q2
        Since q1 has 2 users and qconv2 is not an ancestor node of qconv1, we shouldn't fuse:
                int8
                 /
        qconv1 dq1
           \   /
            add
             |
             q2
             |
            int8
        Instead we can match and fuse this pattern into qconv_binary:
        qconv1  fp32
            \   /
             add
              |
             fp32
        """

        class M(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(3, 3, kernel_size=3, stride=1)
                self.conv2 = torch.nn.Conv2d(3, 3, kernel_size=1, stride=1)
                self.maxpool = torch.nn.MaxPool2d(
                    kernel_size=3, stride=1, padding=0, dilation=1
                )

            def forward(self, x):
                tmp1 = self.conv1(x)
                tmp2 = self.maxpool(x)
                add = torch.add(tmp1, tmp2)
                tmp3 = self.conv2(tmp2)
                return torch.cat((add, tmp3), dim=1)

        mod = M().eval()
        v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1)

        def matcher_check_fn():
            self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_count"], 1)
            # The matched qconv binary pattern should have 2 nodes [qconv, add]
            # instead of 11, which would include the dequant of the binary input and the output quant.
            self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_nodes"], 2)

        self._test_common(
            mod,
            (v,),
            check_quantization=True,
            matcher_check_fn=matcher_check_fn,
        )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qat_qconv2d(self):
        r"""
        This testcase will quantize a single Conv2d module with the QAT flow.
        """

        class M(torch.nn.Module):
            def __init__(
                self,
                **kwargs,
            ):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 128, kernel_size=3, stride=1)
                self.bn = torch.nn.BatchNorm2d(128)

            def forward(self, x):
                return self.bn(self.conv(x))

        mod = M().train()
        v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1)

        def matcher_check_fn():
            # 1. Dequant-conv pattern matched in quantization weight prepack * 1
            #    [dequantize_per_tensor, dequantize_per_channel, clone, convolution]
            self.assertEqual(
                counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 1
            )
            self.assertEqual(
                counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 4
            )
            # 2. QConv2D Unary fusion in post-grad fusion pass * 1
            #    [qconv2d_pointwise_default, quantize_per_tensor]
            self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_count"], 1)
            self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_nodes"], 2)

        self._test_common(
            mod,
            (v,),
            check_quantization=True,
            is_qat=True,
            matcher_check_fn=matcher_check_fn,
        )

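    # Helper: QAT-quantize two Conv2d->BatchNorm->unary_op blocks and check the
    # weight-prepack and qconv2d+unary fusion matchers.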
    def _qat_qconv2d_unary_cpu_test_helper(
        self,
        unary_op=torch.nn.ReLU(),
    ):
        class M(torch.nn.Module):
            def __init__(
                self,
                **kwargs,
            ):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, kernel_size=3, stride=1)
                self.unary_fn = copy.deepcopy(unary_op)
                self.bn = torch.nn.BatchNorm2d(3)
                self.conv2 = torch.nn.Conv2d(3, 3, kernel_size=3, stride=1)
                self.unary_fn2 = copy.deepcopy(unary_op)
                self.bn2 = torch.nn.BatchNorm2d(3)

            def forward(self, x):
                tmp = self.unary_fn(self.bn(self.conv(x)))
                return self.unary_fn2(self.bn2(self.conv2(tmp)))

        mod = M()
        v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1)

        def matcher_check_fn():
            # 1. Dequant-conv pattern matched in quantization weight prepack * 2
            #    [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
            self.assertEqual(
                counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2
            )
            # 2. QConv2D Unary fusion in post-grad fusion pass * 2
            #    [qconv2d_pointwise_default, relu, div_1, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2]
            self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_count"], 2)

        self._test_common(
            mod,
            (v,),
            check_quantization=True,
            is_qat=True,
            matcher_check_fn=matcher_check_fn,
        )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qat_qconv2d_relu(self):
        r"""
        This testcase will quantize the Conv2d->ReLU pattern with the QAT flow.
        """

        self._qat_qconv2d_unary_cpu_test_helper()

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qat_qconv2d_relu6(self):
        r"""
        This testcase will quantize the Conv2d->ReLU6 pattern with the QAT flow.
        """
        self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.ReLU6())

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qat_qconv2d_hardtanh(self):
        r"""
        This testcase will quantize the Conv2d->Hardtanh pattern with the QAT flow.
        """
        self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.Hardtanh())

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qat_qconv2d_silu(self):
        r"""
        This testcase will quantize the Conv2d->SiLU pattern with the QAT flow.
        """
        self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.SiLU())

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qat_qconv2d_hardswish(self):
        r"""
        This testcase will quantize the Conv2d->Hardswish pattern with the QAT flow.
        """
        self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.Hardswish())

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qat_qconv2d_add(self):
        r"""
        This testcase will quantize a Conv2d->Add pattern as:
                 X
               /   \
        Conv1(X)   Conv2(X)
               \   /
                Add
                 |
                 Y
        """

        class M(torch.nn.Module):
            def __init__(
                self,
                **kwargs,
            ):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
                self.bn1 = torch.nn.BatchNorm2d(6)
                self.conv2 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
                self.bn2 = torch.nn.BatchNorm2d(6)

            def forward(self, x):
                x1 = self.bn1(self.conv1(x))
                x2 = self.bn2(self.conv2(x))
                return x1 + x2

        mod = M().train()
        v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1)

        def matcher_check_fn():
            # 1. Dequant-conv pattern matched in quantization weight prepack * 2
            #    [dequantize_per_tensor, dequantize_per_channel, clone, convolution]
            self.assertEqual(
                counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2
            )
            self.assertEqual(
                counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 8
            )
            # 2. Qconv2d Binary fusion in post-grad fusion pass * 1
            #    [qconv2d_pointwise_default_1, dequantize_per_tensor, add_3, quantize_per_tensor]
            self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_count"], 1)
            self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_nodes"], 4)

        self._test_common(
            mod,
            (v,),
            check_quantization=True,
            is_qat=True,
            matcher_check_fn=matcher_check_fn,
        )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qat_qconv2d_add_relu(self):
        r"""
        This testcase will quantize a Conv2d->Add->ReLU pattern as:
                 X
               /   \
        Conv1(X)   Conv2(X)
               \   /
                Add
                 |
                ReLU
                 |
                 Y
        """

        class M(torch.nn.Module):
            def __init__(
                self,
                **kwargs,
            ):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
                self.bn1 = torch.nn.BatchNorm2d(6)
                self.conv2 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
                self.bn2 = torch.nn.BatchNorm2d(6)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x1 = self.bn1(self.conv1(x))
                x2 = self.bn2(self.conv2(x))
                return self.relu(x1 + x2)

        mod = M().train()
        v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1)

        def matcher_check_fn():
            # 1. Dequant-conv pattern matched in quantization weight prepack * 2
            #    [dequantize_per_tensor, dequantize_per_channel, clone, convolution]
            self.assertEqual(
                counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2
            )
            self.assertEqual(
                counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 8
            )
            # 2. Qconv2d Binary fusion in post-grad fusion pass * 1
            #    [qconv2d_pointwise_default_1, dequantize_per_tensor, add_3, relu, quantize_per_tensor]
            self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_count"], 1)
            self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_nodes"], 5)

        self._test_common(
            mod,
            (v,),
            check_quantization=True,
            is_qat=True,
            matcher_check_fn=matcher_check_fn,
        )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qconv2d_dequant_promotion_cpu(self):
        r"""
        This testcase tests whether the dequant node before conv2d is promoted correctly:
                 X
                 |
              Conv1(X)
               /   \
        Conv2(X)   Conv3(X)
               \   /
                Add
                 |
                 Y
        """

        class M(torch.nn.Module):
            def __init__(
                self,
                **kwargs,
            ):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
                self.conv2 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1)
                self.conv3 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1)

            def forward(self, x):
                temp = self.conv1(x)
                temp = self.conv2(temp) + self.conv3(temp)
                return temp

        mod = M().eval()
        v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1)

        def matcher_check_fn():
            # 1. Dequant pattern matcher for dequant promotion * 1
            #    [dequantize_per_tensor]
            self.assertEqual(counters["inductor"]["dequant_promotion_matcher_count"], 1)
            self.assertEqual(counters["inductor"]["dequant_promotion_matcher_nodes"], 1)
            # 2. Dequant-conv pattern matched in quantization weight prepack * 3
            #    [dequantize_per_tensor, dequantize_per_channel, clone, convolution]
            self.assertEqual(
                counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 3
            )
            self.assertEqual(
                counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 12
            )
            # 3. Qconv2d Binary fusion in post-grad fusion pass * 1
            #    [qconv2d_pointwise_default_1, add_3]
            self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_count"], 1)
            self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_nodes"], 2)

        self._test_common(
            mod,
            (v,),
            check_quantization=True,
            matcher_check_fn=matcher_check_fn,
        )

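    # Helper: quantize two stacked Linear modules (optionally preceded by a
    # permute+reshape) and check the qlinear weight-prepack matcher.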
    def _qlinear_cpu_test_helper(
        self,
        inputs,
        int8_mixed_bf16=False,
        do_permute=False,
        matcher_check_fn=None,
        bias=True,
        is_dynamic=False,
        is_qat=False,
    ):
        class M(torch.nn.Module):
            def __init__(self, use_bias, do_permute=False):
                super().__init__()
                self.linear = torch.nn.Linear(4, 3, use_bias)
                self.linear2 = torch.nn.Linear(3, 4, use_bias)
                self.do_permute = do_permute

            def forward(self, x):
                if self.do_permute:
                    x = torch.reshape(torch.permute(x, (0, 2, 3, 1)), (2, 12, 4))
                return self.linear2(self.linear(x))

        mod = M(bias, do_permute=do_permute).eval()

        def _default_matcher_check_fn():
            self.assertEqual(
                counters["inductor"]["qlinear_weight_prepack_matcher_count"], 2
            )

        self._test_common(
            mod,
            inputs,
            check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
            check_quantization=True,
            matcher_check_fn=matcher_check_fn
            if matcher_check_fn is not None
            else _default_matcher_check_fn,
            is_qat=is_qat,
            is_dynamic=is_dynamic,
        )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qlinear_cpu(self):
        r"""
        This testcase will quantize a single Linear module.
        """
        for bias in [True, False]:
            self._qlinear_cpu_test_helper((torch.randn((2, 4)),), bias=bias)

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_dynamic_qlinear_cpu(self):
        r"""
        This testcase will quantize a single Linear module with dynamic quantization.
        """
        for bias in [True, False]:
            self._qlinear_cpu_test_helper(
                (torch.randn((2, 4)),), bias=bias, is_dynamic=True
            )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_dynamic_qlinear_qat_cpu(self):
        r"""
        This testcase will quantize a single Linear module with dynamic quantization and the QAT flow.
        """
        for bias in [True, False]:
            self._qlinear_cpu_test_helper(
                (torch.randn((2, 4)),), bias=bias, is_dynamic=True, is_qat=True
            )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_dynamic_qlinear_input_dim_exceeds_2(self):
        r"""
        This testcase will quantize a single Linear module with dynamic quantization and input dim exceeding 2.
        """
        for bias in [True, False]:
            self._qlinear_cpu_test_helper(
                (torch.randn((2, 3, 4)),), bias=bias, is_dynamic=True
            )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNNBF16
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qlinear_int8_mixed_bf16(self):
        r"""
        This testcase will quantize a single Linear module with int8_mixed_bf16 quantization.
        """
        for bias in [True, False]:
            self._qlinear_cpu_test_helper(
                (torch.randn((2, 4)),), int8_mixed_bf16=True, bias=bias
            )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qlinear_input_dim_exceeds_2(self):
        r"""
        This testcase will quantize a single Linear module with input dim exceeding 2.
        """
        for bias in [True, False]:
            self._qlinear_cpu_test_helper((torch.randn((2, 3, 4)),), bias=bias)

    @skipIfNoDynamoSupport
    @skipIfNoONEDNNBF16
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qlinear_int8_mixed_bf16_input_dim_exceeds_2(self):
        r"""
        This testcase will quantize a single Linear module with int8_mixed_bf16 quantization and input dim exceeding 2.
        """
        for bias in [True, False]:
            self._qlinear_cpu_test_helper(
                (torch.randn((2, 3, 4)),), int8_mixed_bf16=True, bias=bias
            )
1554
1555    @skipIfNoDynamoSupport
1556    @skipIfNoONEDNN
1557    @skipIfRocm
1558    def test_qlinear_input_dim_exceeds_2_and_not_contiguous(self):
1559        r"""
1560        This testcase will quantize a single Linear Module.
1561        * Input dim exceeds 2
1562        * Input not contiguous
1563        """
1564        for bias in [True, False]:
1565
1566            def matcher_check_fn():
1567                self.assertEqual(
1568                    counters["inductor"]["qlinear_weight_prepack_matcher_count"], 2
1569                )
1570                self.assertEqual(
1571                    counters["inductor"]["qlinear_weight_prepack_matcher_nodes"],
1572                    13 if bias else 12,
1573                )
1574
1575            self._qlinear_cpu_test_helper(
1576                (torch.randn((2, 4, 3, 4)),),
1577                do_permute=True,
1578                matcher_check_fn=matcher_check_fn,
1579                bias=bias,
1580            )
1581
1582    @skipIfNoDynamoSupport
1583    @skipIfNoONEDNNBF16
1584    @skipIfNoONEDNN
1585    @skipIfRocm
1586    def test_qlinear_int8_mixed_bf16_input_dim_exceeds_2_and_not_contiguous(self):
1587        r"""
1588        This testcase will quantize a single Linear Module for int8_bf16.
1589        * Input dim exceeds 2
1590        * Input not contiguous
1591        """
1592        for bias in [True, False]:
1593
1594            def matcher_check_fn():
1595                self.assertEqual(
1596                    counters["inductor"]["qlinear_weight_prepack_matcher_count"], 2
1597                )
1598                self.assertEqual(
1599                    counters["inductor"]["qlinear_weight_prepack_matcher_nodes"],
1600                    17 if bias else 16,
1601                )
1602
1603            self._qlinear_cpu_test_helper(
1604                (torch.randn((2, 4, 3, 4)),),
1605                int8_mixed_bf16=True,
1606                do_permute=True,
1607                matcher_check_fn=matcher_check_fn,
1608                bias=bias,
1609            )
1610
1611    def _qlinear_unary_cpu_test_helper(
1612        self, inputs, unary_op=torch.nn.ReLU(), int8_mixed_bf16=False
1613    ):
1614        class M(torch.nn.Module):
1615            def __init__(self, use_bias):
1616                super().__init__()
1617                self.linear = torch.nn.Linear(4, 4, use_bias)
1618                self.unary_fn = copy.deepcopy(unary_op)
1619                self.linear2 = torch.nn.Linear(4, 4, use_bias)
1620                self.unary_fn2 = copy.deepcopy(unary_op)
1621
1622            def forward(self, x):
1623                tmp = self.unary_fn(self.linear(x))
1624                return self.unary_fn2(self.linear2(tmp))
1625
1626        bias_list = [True, False]
1627        for bias in bias_list:
1628            mod = M(bias).eval()
1629
1630            def matcher_check_fn():
1631                # 1. dequant-linear pattern matched in quantization weight prepack
1632                self.assertEqual(
1633                    counters["inductor"]["qlinear_weight_prepack_matcher_count"], 2
1634                )
1635                # 2. QLinear Unary fusion in post-grad fusion pass
1636                self.assertEqual(counters["inductor"]["qlinear_unary_matcher_count"], 2)
1637
1638            self._test_common(
1639                mod,
1640                inputs,
1641                check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
1642                check_quantization=True,
1643                matcher_check_fn=matcher_check_fn,
1644            )
1645
1646    @skipIfNoDynamoSupport
1647    @skipIfNoONEDNN
1648    @skipIfRocm
1649    def test_qlinear_relu_cpu(self):
1650        r"""
1651        This testcase will quantize a Linear->ReLU pattern.
1652        """
1653        self._qlinear_unary_cpu_test_helper((torch.randn((2, 4)),))
1654
1655    @skipIfNoDynamoSupport
1656    @skipIfNoONEDNNBF16
1657    @skipIfNoONEDNN
1658    @skipIfRocm
1659    def test_qlinear_relu_int8_mixed_bf16(self):
1660        r"""
1661        This testcase will quantize a Linear->ReLU pattern with int8_mixed_bf16 quantization.
1662        """
1663        self._qlinear_unary_cpu_test_helper(
1664            (torch.randn((2, 4)),), int8_mixed_bf16=True
1665        )
1666
1667    @skipIfNoDynamoSupport
1668    @skipIfNoONEDNN
1669    @skipIfRocm
1670    def test_qlinear_relu_input_dim_exceeds_2(self):
1671        r"""
1672        This testcase will quantize a Linear->ReLU pattern.
1673        """
1674        self._qlinear_unary_cpu_test_helper((torch.randn((2, 3, 4)),))
1675
1676    @skipIfNoDynamoSupport
1677    @skipIfNoONEDNNBF16
1678    @skipIfNoONEDNN
1679    @skipIfRocm
1680    def test_qlinear_relu_int8_mixed_bf16_input_dim_exceeds_2(self):
1681        r"""
1682        This testcase will quantize a Linear->ReLU pattern with int8_mixed_bf16 quantization.
1683        """
1684        self._qlinear_unary_cpu_test_helper(
1685            (torch.randn((2, 3, 4)),), int8_mixed_bf16=True
1686        )
1687
1688    @skipIfNoDynamoSupport
1689    @skipIfNoONEDNN
1690    @skipIfRocm
1691    def test_qlinear_gelu_cpu(self):
1692        r"""
1693        This testcase will quantize a Linear->GELU pattern.
1694        """
1695        for gelu in [torch.nn.GELU("none"), torch.nn.GELU("tanh")]:
1696            self._qlinear_unary_cpu_test_helper((torch.randn((2, 4)),), gelu)
1697
1698    @skipIfNoDynamoSupport
1699    @skipIfNoONEDNNBF16
1700    @skipIfNoONEDNN
1701    @skipIfRocm
1702    def test_qlinear_gelu_int8_mixed_bf16(self):
1703        r"""
1704        This testcase will quantize a Linear->GELU pattern with int8_mixed_bf16 quantization.
1705        """
1706        for gelu in [torch.nn.GELU("none"), torch.nn.GELU("tanh")]:
1707            self._qlinear_unary_cpu_test_helper(
1708                (torch.randn((2, 4)),), gelu, int8_mixed_bf16=True
1709            )
1710
1711    def _qlinear_add_cpu_test_helper(self, use_relu=False, int8_mixed_bf16=False):
1712        r"""
1713        This testcase will quantize two consecutive Linear->Add(->relu) patterns as:
1714                 X
1715               /   \
1716        linear(X)   linear(X)
1717               \   /
1718                Add
1719                 |
1720           Optional(relu)
1721               /   \
1722        linear(X)   linear(X)
1723               \   /
1724                Add
1725                 |
1726           Optional(relu)
1727                 |
1728                 Y
1729        """
1730
1731        def fake_quant(x):
1732            # to produce a float32 result as extra input
1733            qlib = torch.ops.quantized_decomposed
1734            x = qlib.quantize_per_tensor.default(x, 0.0166785, 42, 0, 255, torch.uint8)
1735            x = qlib.dequantize_per_tensor.default(
1736                x, 0.0166785, 42, 0, 255, torch.uint8
1737            )
1738            return x
1739
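        # For reference, with scale s = 0.0166785 and zero point zp = 42 the
        # round trip above computes roughly q = clamp(round(x / s) + zp, 0, 255)
        # and then (q - zp) * s, so fake_quant returns a float32 tensor that
        # already carries the quantization error of the extra add input.
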
        class M(torch.nn.Module):
            def __init__(
                self,
                add_fn,
                use_relu,
                fake_quant_before_extra_input,
            ):
                super().__init__()
                self.linear1 = torch.nn.Linear(4, 4)
                self.linear2 = torch.nn.Linear(4, 4)
                self.add_fn = add_fn
                self.relu = torch.nn.ReLU()
                self.linear3 = torch.nn.Linear(4, 4)
                self.linear4 = torch.nn.Linear(4, 4)
                self.add_fn2 = add_fn
                self.relu2 = torch.nn.ReLU()
                self.use_relu = use_relu
                self.fake_quant_before_extra_input = fake_quant_before_extra_input

            def forward(self, x):
                x1 = self.linear1(x)
                x2 = self.linear2(x)
                if self.fake_quant_before_extra_input:
                    x2 = fake_quant(x2)
                tmp = self.add_fn(x1, x2)
                if self.use_relu:
                    tmp = self.relu(tmp)
                tmp1 = self.linear3(tmp)
                tmp2 = self.linear4(tmp)
                if self.fake_quant_before_extra_input:
                    tmp2 = fake_quant(tmp2)
                res = self.add_fn2(tmp1, tmp2)
                if self.use_relu:
                    res = self.relu2(res)
                return res

        add_fn_list = [
            lambda x, y: x + y,
            lambda x, y: y + x,
            lambda x, y: x.add_(y),
            lambda x, y: y.add_(x),
        ]
        fake_quant_x2_list = [False, True] if int8_mixed_bf16 else [False]
        cases = itertools.product(add_fn_list, fake_quant_x2_list)
        for add_fn, fq_x2 in cases:
            mod = M(add_fn, use_relu, fq_x2).eval()
            v = torch.randn((4, 4), dtype=torch.float32, requires_grad=False).add(1)

            def matcher_check_fn():
                # 1. Dequant-linear pattern matched in quantization weight prepack * 4
                self.assertEqual(
                    counters["inductor"]["qlinear_weight_prepack_matcher_count"], 4
                )
                # pattern = [dequant_per_tensor, (convert_dtype), dequant_per_channel, (convert_dtype), permute, addmm]
                nodes_per_match = 6 if int8_mixed_bf16 else 4
                self.assertEqual(
                    counters["inductor"]["qlinear_weight_prepack_matcher_nodes"],
                    4 * nodes_per_match,
                )
                # 2. Qlinear Binary Unary fusion in post-grad fusion pass * 2
                self.assertEqual(
                    counters["inductor"]["qlinear_binary_matcher_count"], 2
                )
                # Two linear-binary patterns are matched:
                # matched pattern1 = [qlinear, add, (convert dtype), (relu), quantize_per_tensor]
                # matched pattern2 = [qlinear, add, (convert dtype), (relu)]
                # If add_fn is x.add_(y), x is bf16 and y is fp32, there is a to_bf16 node after binary
                to_bf16_after_binary = 2 * (add_fn == add_fn_list[2] and fq_x2)
                self.assertEqual(
                    counters["inductor"]["qlinear_binary_matcher_nodes"],
                    5 + 2 * use_relu + to_bf16_after_binary,
                )

            for is_qat in [False, True]:
                self._test_common(
                    mod,
                    (v,),
                    check_quantization=True,
                    check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
                    matcher_check_fn=matcher_check_fn,
                    is_qat=is_qat,
                )
                if torch._inductor.config.cpp_wrapper:
                    # For the CPP wrapper
                    self._test_code_common(
                        mod,
                        (v,),
                        [
                            "op_qlinear_pointwise.call",
                            "op_qlinear_pointwise_binary.call",
                        ],
                        [],
                        check_quantization=True,
                        num_include_ops=[2, 2],
                    )
                else:
                    # For the Python wrapper
                    self._test_code_common(
                        mod,
                        (v,),
                        [
                            "torch.ops.onednn.qlinear_pointwise.default",
                            "torch.ops.onednn.qlinear_pointwise.binary",
                        ],
                        [],
                        check_quantization=True,
                        num_include_ops=[2, 2],
                    )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qlinear_add_cpu(self):
        self._qlinear_add_cpu_test_helper()

    @skipIfNoDynamoSupport
    @skipIfNoONEDNNBF16
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qlinear_add_int8_mixed_bf16(self):
        self._qlinear_add_cpu_test_helper(int8_mixed_bf16=True)

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qlinear_add_relu_cpu(self):
        self._qlinear_add_cpu_test_helper(use_relu=True)

    @skipIfNoDynamoSupport
    @skipIfNoONEDNNBF16
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qlinear_add_relu_int8_mixed_bf16(self):
        self._qlinear_add_cpu_test_helper(use_relu=True, int8_mixed_bf16=True)

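    # An eager-mode sketch of the computation that the binary fusion above
    # rewrites (the helper below is illustrative and unused; the real fused op
    # is torch.ops.onednn.qlinear_pointwise.binary): linear + add (+ relu) is
    # collapsed into a single kernel call instead of three separate ops.
    @staticmethod
    def _qlinear_add_reference(x, w1, b1, w2, b2, use_relu=False):
        y = F.linear(x, w1, b1) + F.linear(x, w2, b2)
        return torch.relu(y) if use_relu else y
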
    def _qlinear_dequant_promotion_cpu_test_helper(
        self,
        inputs,
        int8_mixed_bf16=False,
        is_dynamic=False,
        matcher_check_fn=None,
    ):
        class M(torch.nn.Module):
            def __init__(
                self,
                **kwargs,
            ):
                super().__init__()
                self.linear1 = torch.nn.Linear(4, 4)
                self.linear2 = torch.nn.Linear(4, 4)
                self.linear3 = torch.nn.Linear(4, 4)

            def forward(self, x):
                temp = self.linear1(x)
                temp = self.linear2(temp) + self.linear3(temp)
                return temp

        mod = M().eval()

        def default_matcher_check_fn():
            # 1. Dequant pattern matcher for dequant promotion * 1
            self.assertEqual(counters["inductor"]["dequant_promotion_matcher_count"], 1)
            # 2. dequant-linear pattern matched in quantization weight prepack * 3
            self.assertEqual(
                counters["inductor"]["qlinear_weight_prepack_matcher_count"], 3
            )
            # 3. QLinear Unary fusion in post-grad fusion pass * 1
            self.assertEqual(counters["inductor"]["qlinear_unary_matcher_count"], 1)

        self._test_common(
            mod,
            inputs,
            check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
            check_quantization=True,
            matcher_check_fn=matcher_check_fn
            if matcher_check_fn is not None
            else default_matcher_check_fn,
            is_dynamic=is_dynamic,
        )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qlinear_dequant_promotion_cpu(self):
        r"""
        This testcase tests whether the dequant node before linear is promoted correctly:
                  X
                  |
               Linear1(X)
                /   \
        Linear2(X)   Linear3(X)
                \   /
                 Add
                  |
                  Y
        """
        self._qlinear_dequant_promotion_cpu_test_helper((torch.randn((2, 4)),))

    @skipIfNoDynamoSupport
    @skipIfNoONEDNNBF16
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qlinear_dequant_promotion_int8_mixed_bf16(self):
        r"""
        Test with int8_mixed_bf16 quantization.
        This testcase tests whether the dequant node before linear is promoted correctly:
                  X
                  |
               Linear1(X)
                /   \
        Linear2(X)   Linear3(X)
                \   /
                 Add
                  |
                  Y
        """
        self._qlinear_dequant_promotion_cpu_test_helper(
            (torch.randn((2, 4)),), int8_mixed_bf16=True
        )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qlinear_dequant_promotion_cpu_input_dim_exceeds_2(self):
        r"""
        This testcase tests whether the dequant node before linear is promoted correctly:
                  X
                  |
               Linear1(X)
                /   \
        Linear2(X)   Linear3(X)
                \   /
                 Add
                  |
                  Y
        """
        self._qlinear_dequant_promotion_cpu_test_helper((torch.randn((2, 3, 4)),))

    @skipIfNoDynamoSupport
    @skipIfNoONEDNNBF16
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qlinear_dequant_promotion_int8_mixed_bf16_input_dim_exceeds_2(self):
        r"""
        Test with int8_mixed_bf16 quantization.
        This testcase tests whether the dequant node before linear is promoted correctly:
                  X
                  |
               Linear1(X)
                /   \
        Linear2(X)   Linear3(X)
                \   /
                 Add
                  |
                  Y
        """
        self._qlinear_dequant_promotion_cpu_test_helper(
            (torch.randn((2, 3, 4)),), int8_mixed_bf16=True
        )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qlinear_dequant_promotion_dynamic_cpu(self):
        r"""
        This testcase tests whether the dequant node before linear is promoted correctly:
                  X
                  |
               Linear1(X)
                /   \
        Linear2(X)   Linear3(X)
                \   /
                 Add
                  |
                  Y
        """

        def matcher_check_fn():
            # 1. Dequant pattern matcher for dequant promotion * 1
            self.assertEqual(counters["inductor"]["dequant_promotion_matcher_count"], 1)
            # 2. dequant-linear pattern matched in quantization weight prepack * 3
            self.assertEqual(
                counters["inductor"]["qlinear_weight_prepack_matcher_count"], 3
            )

        self._qlinear_dequant_promotion_cpu_test_helper(
            (torch.randn((2, 4)),),
            matcher_check_fn=matcher_check_fn,
            is_dynamic=True,
        )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qlinear_mul_cpu(self):
        r"""
        This testcase will quantize a Linear->Mul pattern.
        """

        class M(torch.nn.Module):
            def __init__(self, use_bias):
                super().__init__()
                self.linear = torch.nn.Linear(4, 5, use_bias)

            def forward(self, x1, x2):
                return torch.mul(self.linear(x1), x2)

        bias_list = [True, False]
        for bias in bias_list:
            mod = M(bias).eval()
            x1 = torch.randn((2, 4))
            x2 = torch.randn((2, 5))

            def matcher_check_fn():
                self.assertEqual(
                    counters["inductor"]["qlinear_weight_prepack_matcher_count"], 1
                )

            self._test_common(
                mod,
                (x1, x2),
                check_quantization=True,
                matcher_check_fn=matcher_check_fn,
            )

    @skipIfNoDynamoSupport
    @skipIfRocm
    def test_qmaxpool2d(self):
        r"""
        This testcase will quantize a Conv2d->ReLU->MaxPool2d pattern.
        """

        class M(torch.nn.Module):
            def __init__(
                self,
                kwargs,
            ):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    3, 64, 7, bias=True, stride=2, padding=3, dilation=1
                )
                self.relu = torch.nn.ReLU()
                self.maxpool = torch.nn.MaxPool2d(3, **kwargs)

            def forward(self, x):
                return self.maxpool(self.relu(self.conv(x)))

        kwargs_list = [
            {"stride": 2},
            {"stride": 2, "padding": 1},
            {"stride": 2, "padding": 1, "dilation": 1},
            {"stride": 2, "padding": 1, "dilation": 1, "ceil_mode": False},
        ]
        for kwargs in kwargs_list:
            mod = M(kwargs).eval()
            v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(
                1
            )

            def matcher_check_fn():
                self.assertEqual(counters["inductor"]["qmaxpool2d_matcher_count"], 1)
                self.assertEqual(
                    counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 1
                )
                self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_count"], 1)

            self._test_common(
                mod,
                (v,),
                check_quantization=True,
                matcher_check_fn=matcher_check_fn,
            )

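    # A sketch of why maxpool can run directly on the quantized uint8 tensor
    # (illustrative helper, not used by the tests): max commutes with the
    # monotonically increasing dequantize affine map, so pooling then
    # dequantizing equals dequantizing then pooling.
    @staticmethod
    def _qmaxpool_commutes_sketch():
        qlib = torch.ops.quantized_decomposed
        q = qlib.quantize_per_tensor.default(
            torch.randn(1, 3, 8, 8), 0.02, 128, 0, 255, torch.uint8
        )
        dq = qlib.dequantize_per_tensor.default(q, 0.02, 128, 0, 255, torch.uint8)
        pool = torch.nn.MaxPool2d(2)
        assert torch.allclose(pool(dq), (pool(q.float()) - 128) * 0.02)
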
    @skipIfNoDynamoSupport
    @skipIfRocm
    def test_qflatten(self):
        r"""
        This testcase will quantize a Conv2d->AdaptiveAvgPool2d->flatten pattern.
        """

        class M(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    3, 64, 7, bias=True, stride=2, padding=3, dilation=1
                )
                self.relu = torch.nn.ReLU()
                self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))

            def forward(self, x):
                return torch.flatten(
                    self.adaptive_avg_pool2d(self.relu(self.conv(x))), 1
                )

        mod = M().eval()
        v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1)

        def matcher_check_fn():
            self.assertEqual(counters["inductor"]["qreshape_matcher_count"], 1)

        self._test_common(
            mod,
            (v,),
            check_quantization=True,
            matcher_check_fn=matcher_check_fn,
        )

    @skipIfNoDynamoSupport
    @skipIfRocm
    def test_qcat(self):
        r"""
        This testcase will quantize a cat-based pattern:
                X
             /     \
        Conv1(X)  Pow(x)
            \        \
             \     Conv2(X)
              \    /
               Cat
                |
                Y
        """

        class M(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    3, 64, 7, bias=True, stride=2, padding=3, dilation=1
                )
                self.conv2 = torch.nn.Conv2d(
                    3, 64, 7, bias=True, stride=2, padding=3, dilation=1
                )

            def forward(self, x):
                temp1 = self.conv(x)
                temp2 = self.conv2(torch.pow(x, 2))
                return torch.cat((temp1, temp2), 1)

        mod = M().eval()
        v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1)

        def matcher_check_fn():
            self.assertEqual(counters["inductor"]["qcat_matcher_count"], 1)
            self.assertEqual(
                counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2
            )
            self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_count"], 2)

        self._test_common(
            mod,
            (v,),
            check_quantization=True,
            matcher_check_fn=matcher_check_fn,
        )

    # https://github.com/pytorch/pytorch/issues/99841.
    def test_hardtanh_pattern_fallback(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv_transpose = torch.nn.ConvTranspose2d(
                    in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1
                )

            def forward(self, x, min_value, max_value):
                conv_transpose_output = self.conv_transpose(x)
                clamp_min_output = torch.clamp_min(conv_transpose_output, min_value)
                clamp_max_output = torch.clamp_max(clamp_min_output, max_value)
                return clamp_max_output

        # Check that this still works when min_value > max_value.
        min_values = [3, torch.randn(1, 32, 28, 28)]
        max_values = [0, torch.randn(1, 32, 28, 28)]
        v = torch.randn(1, 3, 28, 28)
        for min_value, max_value in zip(min_values, max_values):
            mod = Model().eval()
            self._test_common(mod, (v, min_value, max_value), 2, 4)

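    # A sketch of why this must fall back rather than fuse into hardtanh,
    # assuming the fused op requires scalar bounds with min <= max: the eager
    # reference clamps sequentially, so min_value > max_value is still well
    # defined (everything ends up at max_value), which hardtanh with swapped
    # bounds would not reproduce.
    @staticmethod
    def _sequential_clamp_reference(x, min_value, max_value):
        return torch.clamp_max(torch.clamp_min(x, min_value), max_value)
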
    def test_leaky_relu_pattern_fallback(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1
                )

            def forward(self, x, negative_slope):
                conv_out = self.conv(x)
                return torch.where(conv_out > 0, conv_out, conv_out * negative_slope)

        negative_slopes = [0.1, torch.randn(1, 32, 28, 28)]
        with torch.no_grad():
            v = torch.randn(1, 3, 28, 28)
            for negative_slope in negative_slopes:
                mod = Model().eval()
                self._test_common(mod, (v, negative_slope), 2, 5)

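    # The where-based expression above is the decomposed form of leaky_relu; a
    # small sketch (unused helper) of the equivalence the pattern matcher
    # relies on. Fusion is only possible for a scalar slope, which is why the
    # tensor-valued negative_slope case falls back.
    @staticmethod
    def _leaky_relu_equivalence_sketch():
        x = torch.randn(16)
        assert torch.allclose(
            torch.where(x > 0, x, x * 0.1), F.leaky_relu(x, negative_slope=0.1)
        )
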
    # https://github.com/pytorch/pytorch/issues/99838.
    def test_conv2d_add_scalar(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1
                )

            def forward(self, x):
                out_conv = self.conv(x)
                out = torch.add(out_conv, 1.0)
                return out

        with torch.no_grad():
            mod = Model().eval()
            v = torch.randn(1, 3, 28, 28)
            self._test_common(mod, (v,), 1, 1)

    def test_conv2d_binary_inplace_fusion_pass_cpu(
        self, include_ops=None, exclude_ops=None
    ):
        class Model_v1(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1
                )

            def forward(self, x, other):
                conv_out = self.conv(x)
                return torch.add(conv_out, other.relu())

        class Model_v2(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1
                )
                self.conv2 = torch.nn.Conv2d(
                    in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1
                )
                self.conv3 = torch.nn.Conv2d(
                    in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1
                )

            def forward(self, x, _):
                conv_out1 = self.conv(x)
                pow_out = torch.pow(conv_out1, 2)
                conv_out2 = self.conv2(pow_out)
                conv_out3 = self.conv3(conv_out2)
                res = torch.add(conv_out3, pow_out)
                return res

        input = torch.randn(1, 3, 28, 28).to(memory_format=torch.channels_last)
        others = [
            torch.randn(1, 32, 28, 28).to(memory_format=torch.channels_last),
            torch.randn(1, 32, 28, 28).to(memory_format=torch.channels_last),
        ]
        mod_v1 = Model_v1().to(memory_format=torch.channels_last).eval()
        mod_v2 = Model_v2().to(memory_format=torch.channels_last).eval()

        if include_ops is None:
            include_ops = ["mkldnn._convolution_pointwise_.binary"]
        if exclude_ops is None:
            exclude_ops = ["mkldnn._convolution_pointwise.binary"]

        for other, mod in zip(others, [mod_v1, mod_v2]):
            self._test_code_common(mod, (input, other), include_ops, exclude_ops)

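    # In the passing cases above, `other` is produced inside the graph and has
    # no other users, so Inductor may overwrite it and emit the in-place
    # mkldnn._convolution_pointwise_.binary variant. The test below enumerates
    # cases where writing into `other` would be observable from outside the
    # fused op, so only the out-of-place variant is allowed.
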
    def test_conv2d_binary_inplace_fusion_failed_cpu(
        self, include_ops=None, exclude_ops=None
    ):
        # The written buffer is a graph input, so we can't fuse in place.
        class Model_v1(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1
                )

            def forward(self, x, other):
                conv_out = self.conv(x)
                return torch.add(conv_out, other)

        # The written buffer is an alias tensor, so we can't fuse in place.
        class Model_v2(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1
                )

            def forward(self, x, other):
                conv_out = self.conv(x)
                return torch.add(conv_out, other[1:2, :, :, :]), other

        class Model_v3(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1
                )
                self.conv2 = torch.nn.Conv2d(
                    in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1
                )

            def forward(self, x, _):
                pow_out = torch.pow(self.conv(x), 2)
                other2 = F.relu(pow_out)
                conv_out2 = self.conv2(pow_out)
                res = torch.add(conv_out2, pow_out)
                res = res + other2
                return res

        # The written buffer is a ReinterpretView, so we can't fuse in place.
        class Model_v4(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 32, 3, padding=1, bias=True)
                self.linear = torch.nn.Linear(32 * 28, 32 * 28)
                self.relu = torch.nn.ReLU()

            def forward(self, x, y):
                x = self.conv(self.relu(x))
                y = self.linear(y)
                y = torch.cat((y, y), 1)
                y = torch.ops.aten.permute.default(y, [0, 2, 1]).reshape(1, 32, 28, 28)
                return x + y

        class Model_v5(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.relu = torch.nn.ReLU()

            def forward(self, _, x):
                x1 = self.relu(x)
                return self.conv(x1) + x1

        input = torch.randn(1, 3, 28, 28).to(memory_format=torch.channels_last)
        others = [
            torch.randn(1, 32, 28, 28).to(memory_format=torch.channels_last),
            torch.randn(2, 32, 28, 28).to(memory_format=torch.channels_last),
            torch.randn(1, 32, 28, 28).to(memory_format=torch.channels_last),
            torch.randn(1, 14, 32 * 28),
            torch.randn(1, 32, 28, 28).to(memory_format=torch.channels_last),
        ]
        mod_v1 = Model_v1().to(memory_format=torch.channels_last).eval()
        mod_v2 = Model_v2().to(memory_format=torch.channels_last).eval()
        mod_v3 = Model_v3().to(memory_format=torch.channels_last).eval()
        mod_v4 = Model_v4().to(memory_format=torch.channels_last).eval()
        mod_v5 = Model_v5().to(memory_format=torch.channels_last).eval()

        if include_ops is None:
            include_ops = ["mkldnn._convolution_pointwise.binary"]
        if exclude_ops is None:
            exclude_ops = ["mkldnn._convolution_pointwise_.binary"]

        for other, mod in zip(others, [mod_v1, mod_v2, mod_v3, mod_v4, mod_v5]):
            self._test_code_common(mod, (input, other), include_ops, exclude_ops)

    def test_conv2d_binary_fusion_failed(self):
        # We don't support the alpha != 1 case, or the case where `other` has a
        # different size than conv's output.
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1
                )

            def forward(self, x, other, alpha):
                conv_out = self.conv(x)
                return torch.add(conv_out, other, alpha=alpha)

        # https://github.com/pytorch/pytorch/issues/100802.
        # We can't do the fusion when add's inputs are the same tensor.
        class Model2(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1
                )

            def forward(self, x):
                out = self.conv(x)
                out = torch.add(out, out)
                return out

        # https://github.com/pytorch/pytorch/issues/101374.
        # We can't do the fusion when add's inputs have mixed dtypes.
        class Model3(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1
                )

            def forward(self, x):
                temp = self.conv(x)
                other = torch.ones(temp.shape, dtype=torch.double)
                out = torch.add(temp, other)
                return out

        input = torch.randn(1, 3, 28, 28).to(memory_format=torch.channels_last)
        others = [
            torch.randn(1, 32, 28, 28).to(memory_format=torch.channels_last),
            torch.randn(32, 28, 28),
        ]
        include_ops = ["mkldnn._convolution_pointwise"]
        exclude_ops = [
            "mkldnn._convolution_pointwise.binary",
            "mkldnn._convolution_pointwise_.binary",
        ]

        # case 1
        for other, alpha in zip(others, [0.1, 1.0]):
            mod = Model().to(memory_format=torch.channels_last).eval()
            self._test_code_common(mod, (input, other, alpha), include_ops, exclude_ops)
        # case 2
        mod = Model2().to(memory_format=torch.channels_last).eval()
        self._test_code_common(mod, (input,), include_ops, exclude_ops)
        # case 3
        mod = Model3().to(memory_format=torch.channels_last).eval()
        self._test_code_common(mod, (input,), include_ops, exclude_ops)

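    # torch.add's alpha keyword scales the second operand before the addition,
    # so only alpha == 1 matches the plain conv + add pattern. A minimal
    # sketch (unused helper) of the semantics:
    @staticmethod
    def _add_alpha_semantics_sketch():
        a, b = torch.randn(4), torch.randn(4)
        assert torch.allclose(torch.add(a, b, alpha=0.1), a + 0.1 * b)
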
    def test_reproduce_99842_issue(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)

            def forward(self, input_tensor):
                x = self.conv(input_tensor)
                x = F.relu(x + torch.ones(x.size()))
                return x

        input = torch.randn(1, 3, 14, 14)
        mod = Model().eval()
        include_ops = ["mkldnn._convolution_pointwise_.binary"]
        self._test_code_common(mod, (input,), include_ops, [])

    def test_reproduce_113440_issue_1(self):
        class Mod(torch.nn.Module):
            def __init__(
                self,
                add_fn,
                **kwargs,
            ):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
                self.conv2 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
                self.add_fn = add_fn
                self.relu = torch.nn.ReLU(inplace=True)
                self.conv3 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1)
                self.conv4 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1)
                self.add_fn2 = add_fn
                self.relu2 = torch.nn.ReLU(inplace=True)
                self.use_relu = True

            def forward(self, x):
                x1 = self.conv1(x)
                x2 = self.conv2(x)
                tmp = self.add_fn(x1, x2)
                if self.use_relu:
                    tmp = self.relu(tmp)
                tmp1 = self.conv3(tmp)
                tmp2 = self.conv4(tmp)
                res = self.add_fn2(tmp1, tmp2)
                if self.use_relu:
                    res = self.relu2(res)
                return res

        with torch.no_grad():
            example_inputs = (
                torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(
                    1
                ),
            )
            example_inputs[0].get_device()
            m = Mod(
                lambda x, y: x.add_(y),
            ).eval()
            om = torch.compile(m)
            om(*example_inputs)
            om(*example_inputs)

    def test_reproduce_113440_issue_2(self):
        class Mod(torch.nn.Module):
            def __init__(
                self,
                add_fn,
                **kwargs,
            ):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
                self.conv2 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
                self.add_fn = add_fn
                self.relu = torch.nn.ReLU(inplace=True)
                self.conv3 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1)
                self.conv4 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1)
                self.add_fn2 = add_fn
                self.relu2 = torch.nn.ReLU(inplace=True)

                self.conv5 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1)
                self.conv6 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1)
                self.conv7 = torch.nn.Conv2d(6, 6, kernel_size=1, stride=1)
                self.add_fn3 = add_fn
                self.relu3 = torch.nn.ReLU(inplace=True)

                self.use_relu = True

            def forward(self, x):
                x1 = self.conv1(x)
                x2 = self.conv2(x)
                tmp = self.add_fn(x1, x2)
                if self.use_relu:
                    tmp = self.relu(tmp)

                tmp1 = self.conv3(tmp)
                res = self.relu2(tmp1)

                return res

        with torch.no_grad():
            example_inputs = (
                torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(
                    1
                ),
            )
            m = Mod(
                lambda x, y: x.add_(y),
            ).eval()
            om = torch.compile(m)
            om(*example_inputs)
            om(*example_inputs)

    def test_reproduce_121253_issue(self):
        class Mod(torch.nn.Module):
            def __init__(self, weight, bias, beta, alpha):
                super().__init__()
                self.weight = weight
                self.bias = bias
                self.beta = beta
                self.alpha = alpha

            def forward(self, x):
                return torch.addmm(
                    self.bias, x, self.weight, beta=self.beta, alpha=self.alpha
                )

        dtypes = [torch.float32]
        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
            dtypes.append(torch.bfloat16)
        for dtype in dtypes:
            linear_op = (
                "mkl._mkl_linear"
                if dtype == torch.float32
                else "mkldnn._linear_pointwise"
            )
            for beta, alpha in zip([1.0, 0.1, 0.0], [1.0, 0.1, 1.0]):
                weight = torch.randn(64, 64, dtype=dtype)
                bias = torch.randn(64, dtype=dtype)
                mod = Mod(weight, bias, beta, alpha).to(dtype).eval()
                with torch.no_grad():
                    x = torch.randn(1, 64, dtype=dtype)
                    include_ops = []
                    exclude_ops = []
                    if (beta != 1.0 and beta != 0.0) or alpha != 1.0:
                        exclude_ops = [linear_op]
                    else:
                        include_ops = [linear_op]
                    self._test_code_common(mod, (x,), include_ops, exclude_ops)

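    # torch.addmm computes beta * bias + alpha * (x @ weight), so it can only
    # be packed as a linear op for the beta/alpha values where this coincides
    # with a plain (optionally bias-free) linear. A sketch (unused helper) of
    # the semantics being assumed:
    @staticmethod
    def _addmm_semantics_sketch():
        x, w, b = torch.randn(1, 64), torch.randn(64, 64), torch.randn(64)
        out = torch.addmm(b, x, w, beta=0.1, alpha=0.2)
        assert torch.allclose(out, 0.1 * b + 0.2 * (x @ w), atol=1e-5)
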
    @skipIfNoDynamoSupport
    @skipIfRocm
    def test_woq_int8(self):
        class M(torch.nn.Module):
            def forward(self, x, weight, scales):
                return torch.nn.functional.linear(x, weight.to(dtype=x.dtype)) * scales

        mod = M().eval()
        x_shape = (1, 1, 256)
        w_shape = (12, 256)
        s_shape = 12
        x_strides = [
            (256, 256, 1),  # linear dispatching to mm
            (256, 32, 1),  # linear dispatching to bmm
        ]
        for x_stride in x_strides:
            x = torch.randn(x_shape, dtype=torch.bfloat16).as_strided(x_shape, x_stride)
            w = torch.randint(-128, 127, w_shape, dtype=torch.int8)
            s = torch.randn(s_shape, dtype=torch.bfloat16)

            def matcher_check_fn():
                self.assertEqual(counters["inductor"]["woq_matcher_count"], 1)

            self._test_common(
                mod,
                (x, w, s),
                matcher_check_fn=matcher_check_fn,
                check_quantization=False,
                atol=0.001,
                rtol=0.07,
            )

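    # Weight-only quantization keeps int8 weights in memory and dequantizes
    # them on the fly; the module above is exactly the decomposed form the WOQ
    # matcher looks for. A reference of the arithmetic with per-output-channel
    # scales (illustrative helper, not used by the test):
    @staticmethod
    def _woq_int8_reference(x, w_int8, scales):
        return F.linear(x, w_int8.to(dtype=x.dtype)) * scales
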

@dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False})
class TestDynamicPatternMatcher(TestPatternMatcherBase):
    _test_conv_unary_cpu_base = TestPatternMatcher._test_conv_unary_cpu_base
    test_conv2d_unary_dynamic_shapes = TestPatternMatcher.test_conv2d_unary_cpu
    test_conv3d_unary_dynamic_shapes = TestPatternMatcher.test_conv3d_unary_cpu
    _test_conv_binary_base = TestPatternMatcher._test_conv_binary_base
    test_conv2d_binary_dynamic_shapes = TestPatternMatcher.test_conv2d_binary
    test_conv3d_binary_dynamic_shapes = TestPatternMatcher.test_conv3d_binary
    test_linear_unary_dynamic_shapes = TestPatternMatcher.test_linear_unary

    def test_conv_transpose2d_dynamic_shapes(self):
        # We don't support conv_transpose2d for now.
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv_transpose2d = torch.nn.ConvTranspose2d(
                    3, 16, 3, stride=2, padding=1
                )

            def forward(self, x):
                return self.conv_transpose2d(x)

        x_shape = (1, 3, 28, 28)
        mod = M().eval()
        v = torch.randn(x_shape, dtype=torch.float32)
        self._test_common(mod, (v,), 0, 0)

    def test_multi_linear_share_same_input_dynamic(self):
        # LLaMA-style pattern.
        class M(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                self.w1 = torch.nn.Linear(16, 16, bias=False)
                self.w2 = torch.nn.Linear(16, 16, bias=False)

            def forward(self, x):
                return F.silu(self.w1(x)) * F.relu(self.w2(x))

        dtypes = []
        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
            dtypes.append(torch.bfloat16)
        if torch.ops.mkldnn._is_mkldnn_fp16_supported():
            dtypes.append(torch.float16)
        for dtype in dtypes:
            mod = M().to(dtype).eval()
            v = torch.randn(2, 4, 16).to(dtype)
            # 1. view(match_count=4, match_nodes=4).
            # 2. mm to packed linear(match_count=2, match_nodes=2).
            # 3. view+linear+view to linear(match_count=2, match_nodes=6).
            # 4. linear to linear+swish(match_count=1, match_nodes=2).
            # 5. linear to linear+relu(match_count=1, match_nodes=5).

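            # Totals from the five patterns above:
            # match_count = 4 + 2 + 2 + 1 + 1 = 10, match_nodes = 4 + 2 + 6 + 2 + 5 = 19.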
            match_count = 10
            match_nodes = 19
            self._test_common(mod, (v,), match_count, match_nodes, rtol=1e-2, atol=1e-2)

    def test_qconv2d_maxpool2d_linear_dynamic_cpu(self, include_ops=None):
        r"""
        This testcase will quantize a single Conv2d->MaxPool2d->Linear module
        with dynamic batch size input.
        """

        class M(torch.nn.Module):
            def __init__(
                self,
                **kwargs,
            ):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    3, 16, (2, 2), stride=(1, 1), padding=(1, 1)
                )
                self.relu = torch.nn.ReLU()
                self.maxpool2d = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
                self.linear = torch.nn.Linear(16, 16)

            def forward(self, x):
                temp = self.relu(self.conv(x))
                temp = self.maxpool2d(temp)
                temp = self.avgpool(temp)
                temp = torch.flatten(temp, 1)
                return self.linear(temp)

        mod = M().eval()
        v = torch.randn((2, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1)
        if include_ops is None:
            include_ops = [
                "torch.ops.onednn.qconv2d_pointwise",
                "torch.ops.quantized.max_pool2d",
                "torch.ops.onednn.qlinear_pointwise",
            ]
        exclude_ops = []
        self._test_code_common(
            mod,
            (v,),
            include_ops,
            exclude_ops,
            check_quantization=True,
            check_dynamic=True,
        )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qat_bn_conv2d(self):
        r"""
        This testcase will quantize a BatchNorm->Conv2d->BatchNorm module with the QAT flow.
        """

        class M(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.bn1 = torch.nn.BatchNorm2d(3)
                self.bn2 = torch.nn.BatchNorm2d(3)

            def forward(self, x):
                x = self.conv(self.bn1(x))
                return self.bn2(x)

        mod = M().train()
        v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1)

        def matcher_check_fn():
            self.assertEqual(
                counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 1
            )

        self._test_common(
            mod,
            (v,),
            check_quantization=True,
            is_qat=True,
            matcher_check_fn=matcher_check_fn,
        )

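    # For reference, the standard folding algebra that QAT convert applies to
    # conv->bn pairs (gamma/beta are the BN affine parameters, eps its epsilon):
    #   w_folded = w * (gamma / sqrt(running_var + eps)).reshape(-1, 1, 1, 1)
    #   b_folded = (b - running_mean) * gamma / sqrt(running_var + eps) + beta
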
    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_q_attention_block(self):
        class SelfAttnLikeModule(torch.nn.Module):
            def __init__(
                self,
                input_dim,
                transpose_for_score=False,
                num_attention_heads=None,
                attention_head_size=None,
            ) -> None:
                super().__init__()
                self.input_dim = input_dim
                self.q_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
                self.k_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
                self.v_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
                self.softmax = torch.nn.Softmax(dim=-1)
                self.transpose_for_score = transpose_for_score
                if self.transpose_for_score:
                    assert num_attention_heads is not None
                    assert attention_head_size is not None
                    self.num_attention_heads = num_attention_heads
                    self.attention_head_size = attention_head_size

            def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
                new_x_shape = x.size()[:-1] + (
                    self.num_attention_heads,
                    self.attention_head_size,
                )
                x = x.view(new_x_shape)
                return x.permute(0, 2, 1, 3)

            def forward(self, x):
                q = self.q_proj(x)
                k = self.k_proj(x)
                v = self.v_proj(x)
                if self.transpose_for_score:
                    q = self.transpose_for_scores(q)
                    k = self.transpose_for_scores(k)
                    v = self.transpose_for_scores(v)
                scores = torch.matmul(q, k.transpose(-1, -2)) / (self.input_dim**0.5)
                attention = self.softmax(scores)
                weighted = torch.matmul(attention, v)
                return weighted

        for annotate_matmul in [False, True]:
            mod = SelfAttnLikeModule(
                input_dim=64 * 16,
                transpose_for_score=True,
                num_attention_heads=16,
                attention_head_size=64,
            ).eval()
            v = torch.randn(2, 384, 1024)

            def matcher_check_fn():
                self.assertEqual(
                    counters["inductor"]["qlinear_weight_prepack_matcher_count"], 3
                )
                self.assertEqual(
                    counters["inductor"]["qlinear_unary_matcher_count"],
                    3 if annotate_matmul else 0,
                )

            quantizer = X86InductorQuantizer()
            quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
            if annotate_matmul:
                quantizer.set_function_type_qconfig(
                    torch.matmul, quantizer.get_global_quantization_config()
                )

            self._test_common(
                mod,
                (v,),
                check_quantization=True,
                matcher_check_fn=matcher_check_fn,
                quantizer=quantizer,
            )


if __name__ == "__main__":
    if IS_LINUX and HAS_CPU and torch.backends.mkldnn.is_available():
        run_tests()