1# Owner(s): ["oncall: cpu inductor"] 2import sys 3import unittest 4from typing import NamedTuple 5 6import torch 7from torch._inductor import config 8from torch._inductor.test_case import TestCase as InductorTestCase 9from torch.testing._internal.common_device_type import ( 10 get_desired_device_type_test_bases, 11) 12from torch.testing._internal.common_utils import ( 13 IS_MACOS, 14 IS_WINDOWS, 15 slowTest, 16 TEST_WITH_ROCM, 17) 18from torch.testing._internal.inductor_utils import HAS_CPU 19 20 21try: 22 try: 23 from . import ( 24 test_cpu_repro, 25 test_cpu_select_algorithm, 26 test_mkldnn_pattern_matcher, 27 test_torchinductor, 28 test_torchinductor_dynamic_shapes, 29 ) 30 except ImportError: 31 import test_cpu_repro 32 import test_cpu_select_algorithm 33 import test_mkldnn_pattern_matcher 34 import test_torchinductor 35 import test_torchinductor_dynamic_shapes 36except unittest.SkipTest: 37 if __name__ == "__main__": 38 sys.exit(0) 39 raise 40 41 42_desired_test_bases = get_desired_device_type_test_bases() 43RUN_CPU = ( 44 HAS_CPU 45 and any(getattr(x, "device_type", "") == "cpu" for x in _desired_test_bases) 46 and not IS_MACOS 47) 48 49 50class CppWrapperTemplate: 51 pass 52 53 54class TestCppWrapper(InductorTestCase): 55 device = "cpu" 56 57 58class DynamicShapesCppWrapperCpuTests(InductorTestCase): 59 device = "cpu" 60 61 62test_failures_cpp_wrapper = { 63 # conv2d will fallback for dynamic shapes; the fallback path is not yet supported 64 "test_conv2d_unary_cpu_dynamic_shapes": test_torchinductor.TestFailure( 65 ("cpp_wrapper",), is_skip=True 66 ), 67 "test_conv2d_binary_inplace_fusion_failed_cpu_dynamic_shapes": test_torchinductor.TestFailure( 68 ("cpp_wrapper",), is_skip=True 69 ), 70 "test_conv2d_binary_inplace_fusion_pass_cpu_dynamic_shapes": test_torchinductor.TestFailure( 71 ("cpp_wrapper",), is_skip=True 72 ), 73 # aten._native_multi_head_attention.default is not yet supported for dynamic shapes 74 "test_multihead_attention_cpu_dynamic_shapes": test_torchinductor.TestFailure( 75 ("cpp_wrapper",), is_skip=True 76 ), 77} 78if TEST_WITH_ROCM: 79 test_failures_cpp_wrapper.update( 80 { 81 "test_linear_packed": test_torchinductor.TestFailure( 82 ("cpp_wrapper"), is_skip=True 83 ), 84 "test_linear_packed_dynamic_shapes": test_torchinductor.TestFailure( 85 ("cpp_wrapper"), is_skip=True 86 ), 87 } 88 ) 89if config.abi_compatible: 90 xfail_list = [ 91 "test_lstm_packed_change_input_sizes_cpu", 92 *[ 93 func 94 for func in dir(test_cpu_select_algorithm.TestSelectAlgorithmCPU()) 95 if func.startswith("test_linear_with_pointwise") 96 ], 97 ] 98 for test_name in xfail_list: 99 test_failures_cpp_wrapper[test_name] = test_torchinductor.TestFailure( 100 ("cpp_wrapper",), is_skip=False 101 ) 102 test_failures_cpp_wrapper[ 103 f"{test_name}_dynamic_shapes" 104 ] = test_torchinductor.TestFailure(("cpp_wrapper",), is_skip=False) 105 skip_list = [ 106 "test_multihead_attention_cpu", 107 ] 108 for test_name in skip_list: 109 test_failures_cpp_wrapper[test_name] = test_torchinductor.TestFailure( 110 ("cpp_wrapper",), is_skip=True 111 ) 112 test_failures_cpp_wrapper[ 113 f"{test_name}_dynamic_shapes" 114 ] = test_torchinductor.TestFailure(("cpp_wrapper",), is_skip=True) 115 116 117def make_test_case( 118 name, 119 device, 120 tests, 121 condition=True, 122 slow=False, 123 func_inputs=None, 124 code_string_count=None, 125): 126 test_name = f"{name}_{device}" if device else name 127 if code_string_count is None: 128 code_string_count = {} 129 130 func = getattr(tests, test_name) 131 assert 
callable(func), "not a callable" 132 func = slowTest(func) if slow else func 133 134 @config.patch(cpp_wrapper=True) 135 def fn(self): 136 tests.setUpClass() 137 tests.setUp() 138 try: 139 with torch._C._PreserveDispatchKeyGuard(): 140 torch._C._dispatch_tls_set_dispatch_key_included( 141 torch._C.DispatchKey.Dense, True 142 ) 143 144 _, code = test_torchinductor.run_and_get_cpp_code( 145 func, *func_inputs if func_inputs else [] 146 ) 147 self.assertEqual("CppWrapperCodeCache" in code, True) 148 self.assertTrue( 149 all( 150 code.count(string) == code_string_count[string] 151 for string in code_string_count 152 ) 153 ) 154 finally: 155 tests.tearDown() 156 tests.tearDownClass() 157 158 fn.__name__ = test_name 159 import copy 160 161 fn.__dict__ = copy.deepcopy(func.__dict__) 162 if condition: 163 setattr( 164 CppWrapperTemplate, 165 test_name, 166 fn, 167 ) 168 169 170if RUN_CPU: 171 172 class BaseTest(NamedTuple): 173 name: str 174 device: str = "cpu" 175 tests: InductorTestCase = test_torchinductor.CpuTests() 176 condition: bool = True 177 slow: bool = False 178 func_inputs: list = None 179 code_string_count: dict = {} 180 181 for item in [ 182 BaseTest("test_add_complex"), 183 BaseTest("test_add_complex4"), 184 BaseTest("test_as_strided"), # buffer reuse 185 BaseTest("test_bernoulli1"), 186 BaseTest("test_bitwise"), # int32 187 BaseTest("test_bmm1"), 188 BaseTest("test_bmm2"), 189 BaseTest("test_cat"), # alias 190 BaseTest( 191 "test_conv2d_binary_inplace_fusion_failed", 192 "cpu", 193 test_mkldnn_pattern_matcher.TestPatternMatcher(), 194 condition=torch.backends.mkldnn.is_available(), 195 func_inputs=[ 196 None 197 if config.abi_compatible 198 else ["op_mkldnn__convolution_pointwise_binary.call"], 199 None 200 if config.abi_compatible 201 else ["op_mkldnn__convolution_pointwise__binary.call"], 202 ], 203 ), 204 BaseTest( 205 "test_conv2d_binary_inplace_fusion_pass", 206 "cpu", 207 test_mkldnn_pattern_matcher.TestPatternMatcher(), 208 condition=torch.backends.mkldnn.is_available(), 209 func_inputs=[ 210 None 211 if config.abi_compatible 212 else ["op_mkldnn__convolution_pointwise__binary.call"], 213 None 214 if config.abi_compatible 215 else ["op_mkldnn__convolution_pointwise_binary.call"], 216 ], 217 ), 218 BaseTest( 219 "test_conv2d_unary", 220 "cpu", 221 test_mkldnn_pattern_matcher.TestPatternMatcher(), 222 condition=torch.backends.mkldnn.is_available(), 223 slow=True, 224 ), 225 BaseTest("test_conv_transpose2d_packed", "cpu", test_cpu_repro.CPUReproTests()), 226 BaseTest("test_cumsum"), 227 BaseTest("test_custom_op_1"), 228 BaseTest("test_custom_op_2"), 229 BaseTest("test_custom_op_3"), 230 BaseTest("test_dtype_sympy_expr"), 231 BaseTest("test_embedding_bag"), # test default FallbackKernel 232 BaseTest("test_index_put1"), 233 BaseTest("test_index_put_deterministic_fallback"), 234 BaseTest("test_adding_tensor_offsets"), 235 BaseTest("test_inductor_layout_optimization_input_mutations"), 236 BaseTest("test_int_div", "", test_cpu_repro.CPUReproTests()), 237 BaseTest("test_linear1"), 238 BaseTest("test_linear2"), 239 *[ 240 BaseTest(func, "", test_cpu_select_algorithm.TestSelectAlgorithmCPU()) 241 for func in dir(test_cpu_select_algorithm.TestSelectAlgorithmCPU()) 242 if func.startswith("test_linear_with_pointwise") 243 ], 244 BaseTest("test_polar"), 245 BaseTest( 246 "test_linear_binary", 247 "", 248 test_mkldnn_pattern_matcher.TestPatternMatcher(), 249 torch.backends.mkldnn.is_available() 250 and torch.ops.mkldnn._is_mkldnn_bf16_supported(), 251 ), 252 BaseTest( 253 
"test_linear_packed", 254 "", 255 test_cpu_repro.CPUReproTests(), 256 torch.backends.mkldnn.is_available() 257 and ( 258 torch.ops.mkldnn._is_mkldnn_bf16_supported() 259 or torch.ops.mkldnn._is_mkldnn_fp16_supported() 260 ), 261 ), 262 BaseTest( 263 "test_lstm_packed_change_input_sizes", 264 "cpu", 265 test_cpu_repro.CPUReproTests(), 266 condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS, 267 ), 268 BaseTest("test_max_pool2d6"), 269 BaseTest("test_mm_views"), 270 BaseTest("test_multihead_attention", "cpu", test_cpu_repro.CPUReproTests()), 271 BaseTest( 272 "test_multi_threading", 273 condition=not IS_WINDOWS, 274 # Two threads compile, so we expect the output code to be printed twice. 275 code_string_count={"py::gil_scoped_release release;": 2}, 276 ), 277 BaseTest("test_profiler_mark_wrapper_call"), 278 BaseTest( 279 "test_qconv2d", 280 "cpu", 281 test_mkldnn_pattern_matcher.TestPatternMatcher(), 282 condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS, 283 ), 284 BaseTest( 285 "test_qconv2d_relu", 286 "cpu", 287 test_mkldnn_pattern_matcher.TestPatternMatcher(), 288 condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS, 289 ), 290 BaseTest( 291 "test_qconv2d_add", 292 "cpu", 293 test_mkldnn_pattern_matcher.TestPatternMatcher(), 294 condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS, 295 ), 296 BaseTest( 297 "test_qconv2d_add_relu", 298 "cpu", 299 test_mkldnn_pattern_matcher.TestPatternMatcher(), 300 condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS, 301 ), 302 BaseTest( 303 "test_qconv2d_dequant_promotion", 304 "cpu", 305 test_mkldnn_pattern_matcher.TestPatternMatcher(), 306 condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS, 307 ), 308 BaseTest( 309 "test_qconv2d_maxpool2d_linear_dynamic", 310 "cpu", 311 test_mkldnn_pattern_matcher.TestDynamicPatternMatcher(), 312 condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS, 313 func_inputs=[ 314 None 315 if config.abi_compatible 316 else [ 317 "op_onednn_qconv2d_pointwise_.call", 318 "op_quantized_max_pool2d_.call", 319 "op_onednn_qlinear_pointwise_tensor.call", 320 ], 321 ], 322 ), 323 BaseTest( 324 "test_qlinear", 325 "cpu", 326 test_mkldnn_pattern_matcher.TestPatternMatcher(), 327 condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS, 328 ), 329 BaseTest( 330 "test_qlinear_relu", 331 "cpu", 332 test_mkldnn_pattern_matcher.TestPatternMatcher(), 333 condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS, 334 ), 335 BaseTest( 336 "test_qlinear_gelu", 337 "cpu", 338 test_mkldnn_pattern_matcher.TestPatternMatcher(), 339 condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS, 340 ), 341 BaseTest( 342 "test_qlinear_add", 343 "cpu", 344 test_mkldnn_pattern_matcher.TestPatternMatcher(), 345 condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS, 346 ), 347 BaseTest( 348 "test_qlinear_add_relu", 349 "cpu", 350 test_mkldnn_pattern_matcher.TestPatternMatcher(), 351 condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS, 352 ), 353 BaseTest( 354 "test_qlinear_dequant_promotion", 355 "cpu", 356 test_mkldnn_pattern_matcher.TestPatternMatcher(), 357 condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS, 358 ), 359 BaseTest( 360 "test_dynamic_qlinear", 361 "cpu", 362 test_mkldnn_pattern_matcher.TestPatternMatcher(), 363 condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS, 364 ), 365 BaseTest( 366 "test_dynamic_qlinear_qat", 367 "cpu", 368 
            test_mkldnn_pattern_matcher.TestPatternMatcher(),
            condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS,
        ),
        BaseTest("test_randint"),
        BaseTest("test_randn_with_dtype_and_device"),
        BaseTest("test_reduction1"),  # Reduction
        BaseTest("test_relu"),  # multiple inputs
        BaseTest("test_repeat_interleave", "", test_cpu_repro.CPUReproTests()),
        BaseTest("test_scalar_input"),
        BaseTest("test_scalar_output"),
        BaseTest("test_scaled_dot_product_attention"),
        BaseTest("test_scatter1"),
        BaseTest("test_scatter2"),
        BaseTest("test_scatter3"),
        BaseTest("test_scatter4"),
        BaseTest("test_scatter5"),
        BaseTest("test_scatter6"),
        BaseTest("test_scatter_reduce1"),
        BaseTest("test_scatter_reduce2"),
        BaseTest("test_scatter_reduce3"),
        BaseTest("test_silu"),  # single input, single output
        BaseTest("test_sort"),
        BaseTest("test_sum_dtype"),  # float64
        BaseTest("test_sum_int"),  # bool, int64, int8, uint8
        BaseTest("test_tensor2"),  # constant input
        BaseTest(
            "test_transpose", code_string_count={".reset();": 2}
        ),  # multiple outputs, buffer clear
        BaseTest("test_view_as_complex"),
        BaseTest("test_view_as_real"),
    ]:
        make_test_case(
            item.name,
            item.device,
            item.tests,
            item.condition,
            item.slow,
            item.func_inputs,
            item.code_string_count,
        )

    # Copy the collected template tests onto the static-shape test class.
    test_torchinductor.copy_tests(
        CppWrapperTemplate,
        TestCppWrapper,
        "cpp_wrapper",
        test_failures_cpp_wrapper,
    )

    # Generate dynamic-shape variants of the same tests and copy them onto
    # their own test class.
    DynamicShapesCppWrapperTemplate = (
        test_torchinductor_dynamic_shapes.make_dynamic_cls(CppWrapperTemplate)
    )

    test_torchinductor.copy_tests(
        DynamicShapesCppWrapperTemplate,
        DynamicShapesCppWrapperCpuTests,
        "cpp_wrapper",
        test_failures_cpp_wrapper,
        xfail_prop="_expected_failure_dynamic_wrapper",
    )


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if RUN_CPU:
        run_tests(needs="filelock")