xref: /aosp_15_r20/external/pytorch/test/inductor/test_cpu_cpp_wrapper.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: cpu inductor"]
2import sys
3import unittest
4from typing import NamedTuple
5
6import torch
7from torch._inductor import config
8from torch._inductor.test_case import TestCase as InductorTestCase
9from torch.testing._internal.common_device_type import (
10    get_desired_device_type_test_bases,
11)
12from torch.testing._internal.common_utils import (
13    IS_MACOS,
14    IS_WINDOWS,
15    slowTest,
16    TEST_WITH_ROCM,
17)
18from torch.testing._internal.inductor_utils import HAS_CPU
19
20
21try:
22    try:
23        from . import (
24            test_cpu_repro,
25            test_cpu_select_algorithm,
26            test_mkldnn_pattern_matcher,
27            test_torchinductor,
28            test_torchinductor_dynamic_shapes,
29        )
30    except ImportError:
31        import test_cpu_repro
32        import test_cpu_select_algorithm
33        import test_mkldnn_pattern_matcher
34        import test_torchinductor
35        import test_torchinductor_dynamic_shapes
36except unittest.SkipTest:
37    if __name__ == "__main__":
38        sys.exit(0)
39    raise
40
41
42_desired_test_bases = get_desired_device_type_test_bases()
43RUN_CPU = (
44    HAS_CPU
45    and any(getattr(x, "device_type", "") == "cpu" for x in _desired_test_bases)
46    and not IS_MACOS
47)
48
49
class CppWrapperTemplate:
    """Holder class for the generated cpp_wrapper test methods.

    ``make_test_case`` attaches test methods to this class with ``setattr``;
    they are later copied onto the concrete test classes via ``copy_tests``.
    """
52
53
54class TestCppWrapper(InductorTestCase):
55    device = "cpu"
56
57
58class DynamicShapesCppWrapperCpuTests(InductorTestCase):
59    device = "cpu"
60
61
# Known cpp_wrapper problems, keyed by generated test name (including the
# "_dynamic_shapes" suffix where relevant). Each value is a
# test_torchinductor.TestFailure whose first argument is a tuple of affected
# suffixes; is_skip=True is used for tests that are skipped (skip_list below),
# is_skip=False for expected failures (xfail_list below). The mapping is passed
# to test_torchinductor.copy_tests when the cpp_wrapper classes are populated.
test_failures_cpp_wrapper = {
    # conv2d will fall back for dynamic shapes; the fallback path is not yet supported
    "test_conv2d_unary_cpu_dynamic_shapes": test_torchinductor.TestFailure(
        ("cpp_wrapper",), is_skip=True
    ),
    "test_conv2d_binary_inplace_fusion_failed_cpu_dynamic_shapes": test_torchinductor.TestFailure(
        ("cpp_wrapper",), is_skip=True
    ),
    "test_conv2d_binary_inplace_fusion_pass_cpu_dynamic_shapes": test_torchinductor.TestFailure(
        ("cpp_wrapper",), is_skip=True
    ),
    # aten._native_multi_head_attention.default is not yet supported for dynamic shapes
    "test_multihead_attention_cpu_dynamic_shapes": test_torchinductor.TestFailure(
        ("cpp_wrapper",), is_skip=True
    ),
}
if TEST_WITH_ROCM:
    # The trailing commas matter: ("cpp_wrapper") without one is a plain str,
    # and `suffix in "cpp_wrapper"` is a substring test rather than a tuple
    # membership test.
    test_failures_cpp_wrapper.update(
        {
            "test_linear_packed": test_torchinductor.TestFailure(
                ("cpp_wrapper",), is_skip=True
            ),
            "test_linear_packed_dynamic_shapes": test_torchinductor.TestFailure(
                ("cpp_wrapper",), is_skip=True
            ),
        }
    )
if config.abi_compatible:
    # Expected failures in ABI-compatible mode: the LSTM repro plus every
    # test_linear_with_pointwise* test found on TestSelectAlgorithmCPU.
    xfail_list = [
        "test_lstm_packed_change_input_sizes_cpu",
        *[
            func
            for func in dir(test_cpu_select_algorithm.TestSelectAlgorithmCPU())
            if func.startswith("test_linear_with_pointwise")
        ],
    ]
    # Each entry is registered for both the static and "_dynamic_shapes" variant.
    for test_name in xfail_list:
        test_failures_cpp_wrapper[test_name] = test_torchinductor.TestFailure(
            ("cpp_wrapper",), is_skip=False
        )
        test_failures_cpp_wrapper[
            f"{test_name}_dynamic_shapes"
        ] = test_torchinductor.TestFailure(("cpp_wrapper",), is_skip=False)
    # Tests skipped outright in ABI-compatible mode (static and dynamic variants).
    skip_list = [
        "test_multihead_attention_cpu",
    ]
    for test_name in skip_list:
        test_failures_cpp_wrapper[test_name] = test_torchinductor.TestFailure(
            ("cpp_wrapper",), is_skip=True
        )
        test_failures_cpp_wrapper[
            f"{test_name}_dynamic_shapes"
        ] = test_torchinductor.TestFailure(("cpp_wrapper",), is_skip=True)
115
116
def make_test_case(
    name,
    device,
    tests,
    condition=True,
    slow=False,
    func_inputs=None,
    code_string_count=None,
):
    """Register a cpp_wrapper variant of an existing test on CppWrapperTemplate.

    The source test ``f"{name}_{device}"`` (or just ``name`` when ``device`` is
    falsy) is looked up on ``tests``. A new test method of the same name is
    created that, with ``config.cpp_wrapper`` patched to True, runs the source
    test through ``test_torchinductor.run_and_get_cpp_code``. It then checks
    that the generated code mentions ``CppWrapperCodeCache`` and contains each
    substring in ``code_string_count`` the expected number of times.

    Args:
        name: base name of the source test.
        device: device suffix appended to ``name``; falsy means no suffix.
        tests: test-case instance owning the source test; its setUpClass/setUp
            and tearDown/tearDownClass are called around each run.
        condition: when falsy, the generated test is built but not registered.
        slow: wrap the source test with ``slowTest``.
        func_inputs: optional positional arguments forwarded to the source test.
        code_string_count: optional mapping ``{substring: expected_count}``
            checked against the generated code; None means no such checks.

    Raises:
        AttributeError: if ``tests`` has no attribute named after the test.
        TypeError: if that attribute is not callable.
    """
    import copy

    test_name = f"{name}_{device}" if device else name
    if code_string_count is None:
        code_string_count = {}

    func = getattr(tests, test_name)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not callable(func):
        raise TypeError(f"{test_name} is not a callable")
    func = slowTest(func) if slow else func

    @config.patch(cpp_wrapper=True)
    def fn(self):
        tests.setUpClass()
        tests.setUp()
        try:
            with torch._C._PreserveDispatchKeyGuard():
                # Add the Dense key to the thread-local dispatch include set;
                # the enclosing guard presumably restores the prior TLS state.
                torch._C._dispatch_tls_set_dispatch_key_included(
                    torch._C.DispatchKey.Dense, True
                )

                _, code = test_torchinductor.run_and_get_cpp_code(
                    func, *(func_inputs or [])
                )
                self.assertIn("CppWrapperCodeCache", code)
                # One assertion per substring so a failure names the culprit.
                for string, expected in code_string_count.items():
                    self.assertEqual(
                        code.count(string),
                        expected,
                        msg=f"unexpected number of occurrences of {string!r} in generated code",
                    )
        finally:
            tests.tearDown()
            tests.tearDownClass()

    fn.__name__ = test_name
    # Carry over the source test's attribute dict (e.g. anything set on it by
    # decorators such as slowTest) so the generated test keeps those markers.
    fn.__dict__ = copy.deepcopy(func.__dict__)
    if condition:
        setattr(
            CppWrapperTemplate,
            test_name,
            fn,
        )
168
169
if RUN_CPU:
    from typing import Optional

    class BaseTest(NamedTuple):
        # One source test to re-run under cpp_wrapper. Field names match the
        # parameters of make_test_case so an entry is forwarded via _asdict().
        name: str
        device: str = "cpu"
        # NOTE: this default is a single CpuTests instance, created once and
        # shared by every entry that does not supply its own `tests`.
        tests: InductorTestCase = test_torchinductor.CpuTests()
        condition: bool = True
        slow: bool = False
        # Positional arguments forwarded to the source test, or None.
        func_inputs: Optional[list] = None
        # {substring: expected count} to check in the generated code. Defaults
        # to None rather than a shared `{}`; make_test_case maps None to {}.
        code_string_count: Optional[dict] = None

    for item in [
        BaseTest("test_add_complex"),
        BaseTest("test_add_complex4"),
        BaseTest("test_as_strided"),  # buffer reuse
        BaseTest("test_bernoulli1"),
        BaseTest("test_bitwise"),  # int32
        BaseTest("test_bmm1"),
        BaseTest("test_bmm2"),
        BaseTest("test_cat"),  # alias
        BaseTest(
            "test_conv2d_binary_inplace_fusion_failed",
            "cpu",
            test_mkldnn_pattern_matcher.TestPatternMatcher(),
            condition=torch.backends.mkldnn.is_available(),
            func_inputs=[
                None
                if config.abi_compatible
                else ["op_mkldnn__convolution_pointwise_binary.call"],
                None
                if config.abi_compatible
                else ["op_mkldnn__convolution_pointwise__binary.call"],
            ],
        ),
        BaseTest(
            "test_conv2d_binary_inplace_fusion_pass",
            "cpu",
            test_mkldnn_pattern_matcher.TestPatternMatcher(),
            condition=torch.backends.mkldnn.is_available(),
            func_inputs=[
                None
                if config.abi_compatible
                else ["op_mkldnn__convolution_pointwise__binary.call"],
                None
                if config.abi_compatible
                else ["op_mkldnn__convolution_pointwise_binary.call"],
            ],
        ),
        BaseTest(
            "test_conv2d_unary",
            "cpu",
            test_mkldnn_pattern_matcher.TestPatternMatcher(),
            condition=torch.backends.mkldnn.is_available(),
            slow=True,
        ),
        BaseTest("test_conv_transpose2d_packed", "cpu", test_cpu_repro.CPUReproTests()),
        BaseTest("test_cumsum"),
        BaseTest("test_custom_op_1"),
        BaseTest("test_custom_op_2"),
        BaseTest("test_custom_op_3"),
        BaseTest("test_dtype_sympy_expr"),
        BaseTest("test_embedding_bag"),  # test default FallbackKernel
        BaseTest("test_index_put1"),
        BaseTest("test_index_put_deterministic_fallback"),
        BaseTest("test_adding_tensor_offsets"),
        BaseTest("test_inductor_layout_optimization_input_mutations"),
        BaseTest("test_int_div", "", test_cpu_repro.CPUReproTests()),
        BaseTest("test_linear1"),
        BaseTest("test_linear2"),
        *[
            BaseTest(func, "", test_cpu_select_algorithm.TestSelectAlgorithmCPU())
            for func in dir(test_cpu_select_algorithm.TestSelectAlgorithmCPU())
            if func.startswith("test_linear_with_pointwise")
        ],
        BaseTest("test_polar"),
        BaseTest(
            "test_linear_binary",
            "",
            test_mkldnn_pattern_matcher.TestPatternMatcher(),
            torch.backends.mkldnn.is_available()
            and torch.ops.mkldnn._is_mkldnn_bf16_supported(),
        ),
        BaseTest(
            "test_linear_packed",
            "",
            test_cpu_repro.CPUReproTests(),
            torch.backends.mkldnn.is_available()
            and (
                torch.ops.mkldnn._is_mkldnn_bf16_supported()
                or torch.ops.mkldnn._is_mkldnn_fp16_supported()
            ),
        ),
        BaseTest(
            "test_lstm_packed_change_input_sizes",
            "cpu",
            test_cpu_repro.CPUReproTests(),
            condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS,
        ),
        BaseTest("test_max_pool2d6"),
        BaseTest("test_mm_views"),
        BaseTest("test_multihead_attention", "cpu", test_cpu_repro.CPUReproTests()),
        BaseTest(
            "test_multi_threading",
            condition=not IS_WINDOWS,
            # Two threads compile, so we expect the output code to be printed twice.
            code_string_count={"py::gil_scoped_release release;": 2},
        ),
        BaseTest("test_profiler_mark_wrapper_call"),
        BaseTest(
            "test_qconv2d",
            "cpu",
            test_mkldnn_pattern_matcher.TestPatternMatcher(),
            condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS,
        ),
        BaseTest(
            "test_qconv2d_relu",
            "cpu",
            test_mkldnn_pattern_matcher.TestPatternMatcher(),
            condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS,
        ),
        BaseTest(
            "test_qconv2d_add",
            "cpu",
            test_mkldnn_pattern_matcher.TestPatternMatcher(),
            condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS,
        ),
        BaseTest(
            "test_qconv2d_add_relu",
            "cpu",
            test_mkldnn_pattern_matcher.TestPatternMatcher(),
            condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS,
        ),
        BaseTest(
            "test_qconv2d_dequant_promotion",
            "cpu",
            test_mkldnn_pattern_matcher.TestPatternMatcher(),
            condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS,
        ),
        BaseTest(
            "test_qconv2d_maxpool2d_linear_dynamic",
            "cpu",
            test_mkldnn_pattern_matcher.TestDynamicPatternMatcher(),
            condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS,
            func_inputs=[
                None
                if config.abi_compatible
                else [
                    "op_onednn_qconv2d_pointwise_.call",
                    "op_quantized_max_pool2d_.call",
                    "op_onednn_qlinear_pointwise_tensor.call",
                ],
            ],
        ),
        BaseTest(
            "test_qlinear",
            "cpu",
            test_mkldnn_pattern_matcher.TestPatternMatcher(),
            condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS,
        ),
        BaseTest(
            "test_qlinear_relu",
            "cpu",
            test_mkldnn_pattern_matcher.TestPatternMatcher(),
            condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS,
        ),
        BaseTest(
            "test_qlinear_gelu",
            "cpu",
            test_mkldnn_pattern_matcher.TestPatternMatcher(),
            condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS,
        ),
        BaseTest(
            "test_qlinear_add",
            "cpu",
            test_mkldnn_pattern_matcher.TestPatternMatcher(),
            condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS,
        ),
        BaseTest(
            "test_qlinear_add_relu",
            "cpu",
            test_mkldnn_pattern_matcher.TestPatternMatcher(),
            condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS,
        ),
        BaseTest(
            "test_qlinear_dequant_promotion",
            "cpu",
            test_mkldnn_pattern_matcher.TestPatternMatcher(),
            condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS,
        ),
        BaseTest(
            "test_dynamic_qlinear",
            "cpu",
            test_mkldnn_pattern_matcher.TestPatternMatcher(),
            condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS,
        ),
        BaseTest(
            "test_dynamic_qlinear_qat",
            "cpu",
            test_mkldnn_pattern_matcher.TestPatternMatcher(),
            condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS,
        ),
        BaseTest("test_randint"),
        BaseTest("test_randn_with_dtype_and_device"),
        BaseTest("test_reduction1"),  # Reduction
        BaseTest("test_relu"),  # multiple inputs
        BaseTest("test_repeat_interleave", "", test_cpu_repro.CPUReproTests()),
        BaseTest("test_scalar_input"),
        BaseTest("test_scalar_output"),
        BaseTest("test_scaled_dot_product_attention"),
        BaseTest("test_scatter1"),
        BaseTest("test_scatter2"),
        BaseTest("test_scatter3"),
        BaseTest("test_scatter4"),
        BaseTest("test_scatter5"),
        BaseTest("test_scatter6"),
        BaseTest("test_scatter_reduce1"),
        BaseTest("test_scatter_reduce2"),
        BaseTest("test_scatter_reduce3"),
        BaseTest("test_silu"),  # single input, single output
        BaseTest("test_sort"),
        BaseTest("test_sum_dtype"),  # float64
        BaseTest("test_sum_int"),  # bool, int64, int8, uint8
        BaseTest("test_tensor2"),  # constant input
        BaseTest(
            "test_transpose", code_string_count={".reset();": 2}
        ),  # multiple outputs, buffer clear
        BaseTest("test_view_as_complex"),
        BaseTest("test_view_as_real"),
    ]:
        # Field names equal make_test_case's parameter names, so forward by name.
        make_test_case(**item._asdict())

    # Copy the generated tests from the template onto TestCppWrapper, passing
    # the known cpp_wrapper failures/skips.
    test_torchinductor.copy_tests(
        CppWrapperTemplate,
        TestCppWrapper,
        "cpp_wrapper",
        test_failures_cpp_wrapper,
    )

    # Presumably builds a dynamic-shapes variant of the template class.
    DynamicShapesCppWrapperTemplate = (
        test_torchinductor_dynamic_shapes.make_dynamic_cls(CppWrapperTemplate)
    )

    test_torchinductor.copy_tests(
        DynamicShapesCppWrapperTemplate,
        DynamicShapesCppWrapperCpuTests,
        "cpp_wrapper",
        test_failures_cpp_wrapper,
        xfail_prop="_expected_failure_dynamic_wrapper",
    )
427
428
if __name__ == "__main__":
    # The cpp_wrapper test classes are only populated when RUN_CPU is true,
    # so there is nothing to run otherwise.
    if RUN_CPU:
        from torch._inductor.test_case import run_tests

        run_tests(needs="filelock")
434