# Owner(s): ["module: cpp-extensions"]

import glob
import os
import re
import shutil
import subprocess
import sys
import tempfile
import unittest
import warnings

import torch
import torch.backends.cudnn
import torch.multiprocessing as mp
import torch.testing._internal.common_utils as common
import torch.utils.cpp_extension
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_CUDNN
from torch.testing._internal.common_utils import gradcheck
from torch.utils.cpp_extension import (
    _TORCH_PATH,
    check_compiler_is_gcc,
    CUDA_HOME,
    get_cxx_compiler,
    remove_extension_h_precompiler_headers,
    ROCM_HOME,
)


# define TEST_ROCM before changing TEST_CUDA
TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None
TEST_CUDA = TEST_CUDA and CUDA_HOME is not None
TEST_MPS = torch.backends.mps.is_available()
IS_WINDOWS = sys.platform == "win32"
IS_LINUX = sys.platform.startswith("linux")


def remove_build_path():
    default_build_root = torch.utils.cpp_extension.get_default_build_root()
    if os.path.exists(default_build_root):
        if IS_WINDOWS:
            # rmtree raises a permission error ([WinError 5] Access is denied)
            # on Windows; this is a workaround
            subprocess.run(["rm", "-rf", default_build_root], stdout=subprocess.PIPE)
        else:
            shutil.rmtree(default_build_root)


# There's only one test that runs gradcheck; run slow mode manually
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCppExtensionJIT(common.TestCase):
    """Tests just-in-time cpp extensions.
    Don't confuse this with the PyTorch JIT (aka TorchScript).
    """

    def setUp(self):
        super().setUp()
        # cpp extensions use relative paths. Those paths are relative to
        # this file, so we'll change the working directory temporarily
        self.old_working_dir = os.getcwd()
        os.chdir(os.path.dirname(os.path.abspath(__file__)))

    def tearDown(self):
        super().tearDown()
        # return to the original working directory (see setUp)
        os.chdir(self.old_working_dir)

    @classmethod
    def setUpClass(cls):
        remove_build_path()

    @classmethod
    def tearDownClass(cls):
        remove_build_path()

    def test_jit_compile_extension(self):
        module = torch.utils.cpp_extension.load(
            name="jit_extension",
            sources=[
                "cpp_extensions/jit_extension.cpp",
                "cpp_extensions/jit_extension2.cpp",
            ],
            extra_include_paths=[
                "cpp_extensions",
                "path / with spaces in it",
                "path with quote'",
            ],
            extra_cflags=["-g"],
            verbose=True,
        )
        x = torch.randn(4, 4)
        y = torch.randn(4, 4)

        z = module.tanh_add(x, y)
        self.assertEqual(z, x.tanh() + y.tanh())

        # Check that we can call a method not defined in the main C++ file.
        z = module.exp_add(x, y)
        self.assertEqual(z, x.exp() + y.exp())

        # Check that we can use this JIT-compiled class.
        doubler = module.Doubler(2, 2)
        self.assertIsNone(doubler.get().grad)
        self.assertEqual(doubler.get().sum(), 4)
        self.assertEqual(doubler.forward().sum(), 8)

    @unittest.skipIf(not (TEST_CUDA or TEST_ROCM), "CUDA not found")
    def test_jit_cuda_extension(self):
        # NOTE: The name of the extension must equal the name of the module.
        module = torch.utils.cpp_extension.load(
            name="torch_test_cuda_extension",
            sources=[
                "cpp_extensions/cuda_extension.cpp",
                "cpp_extensions/cuda_extension.cu",
            ],
            extra_cuda_cflags=["-O2"],
            verbose=True,
            keep_intermediates=False,
        )

        x = torch.zeros(100, device="cuda", dtype=torch.float32)
        y = torch.zeros(100, device="cuda", dtype=torch.float32)

        z = module.sigmoid_add(x, y).cpu()

        # 2 * sigmoid(0) = 2 * 0.5 = 1
        self.assertEqual(z, torch.ones_like(z))

    @unittest.skipIf(not TEST_MPS, "MPS not found")
    def test_mps_extension(self):
        module = torch.utils.cpp_extension.load(
            name="torch_test_mps_extension",
            sources=[
                "cpp_extensions/mps_extension.mm",
            ],
            verbose=True,
            keep_intermediates=False,
        )

        tensor_length = 100000
        x = torch.randn(tensor_length, device="cpu", dtype=torch.float32)
        y = torch.randn(tensor_length, device="cpu", dtype=torch.float32)

        cpu_output = module.get_cpu_add_output(x, y)
        mps_output = module.get_mps_add_output(x.to("mps"), y.to("mps"))

        self.assertEqual(cpu_output, mps_output.to("cpu"))

    def _run_jit_cuda_archflags(self, flags, expected):
        # Compile an extension with given `flags`
        def _check_cuobjdump_output(expected_values, is_ptx=False):
            elf_or_ptx = "--list-ptx" if is_ptx else "--list-elf"
            lib_ext = ".pyd" if IS_WINDOWS else ".so"
            # Note: the extension name may include _v1, _v2, so first find the exact name
            ext_filename = glob.glob(
                os.path.join(temp_dir, "cudaext_archflag*" + lib_ext)
            )[0]
            command = ["cuobjdump", elf_or_ptx, ext_filename]
            p = subprocess.Popen(
                command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
            output, err = p.communicate()
            output = output.decode("ascii")
            err = err.decode("ascii")

            if not p.returncode == 0 or not err == "":
                raise AssertionError(
                    f"Flags: {flags}\nReturncode: {p.returncode}\nStderr: {err}\n"
                    f"Output: {output} "
                )

            actual_arches = sorted(re.findall(r"sm_\d\d", output))
            expected_arches = sorted(["sm_" + xx for xx in expected_values])
            self.assertEqual(
                actual_arches,
                expected_arches,
                msg=f"Flags: {flags}, Actual: {actual_arches}, Expected: {expected_arches}\n"
                f"Stderr: {err}\nOutput: {output}",
            )

        temp_dir = tempfile.mkdtemp()
        old_envvar = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
        try:
            os.environ["TORCH_CUDA_ARCH_LIST"] = flags

            params = {
                "name": "cudaext_archflags",
                "sources": [
                    "cpp_extensions/cuda_extension.cpp",
                    "cpp_extensions/cuda_extension.cu",
                ],
                "extra_cuda_cflags": ["-O2"],
                "verbose": True,
                "build_directory": temp_dir,
            }

            if IS_WINDOWS:
                p = mp.Process(target=torch.utils.cpp_extension.load, kwargs=params)

                # Compile and load the test CUDA arch in a different Python process to avoid
                # polluting the current one and causing test_jit_cuda_extension to fail on
                # Windows. There is no clear way to unload a module after it has been imported,
                # and torch.utils.cpp_extension.load builds and loads the module in one go.
                # See https://github.com/pytorch/pytorch/issues/61655 for more details
                p.start()
                p.join()
            else:
                torch.utils.cpp_extension.load(**params)

            # Expected output for --list-elf:
            #   ELF file 1: cudaext_archflags.1.sm_61.cubin
            #   ELF file 2: cudaext_archflags.2.sm_52.cubin
            _check_cuobjdump_output(expected[0])
            if expected[1] is not None:
                # Expected output for --list-ptx:
                #   PTX file 1: cudaext_archflags.1.sm_61.ptx
                _check_cuobjdump_output(expected[1], is_ptx=True)
        finally:
            if IS_WINDOWS:
                # rmtree raises a permission error ([WinError 5] Access is denied)
                # on Windows; this is a workaround
                subprocess.run(["rm", "-rf", temp_dir], stdout=subprocess.PIPE)
            else:
                shutil.rmtree(temp_dir)

            if old_envvar is None:
                os.environ.pop("TORCH_CUDA_ARCH_LIST")
            else:
                os.environ["TORCH_CUDA_ARCH_LIST"] = old_envvar

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    @unittest.skipIf(TEST_ROCM, "disabled on rocm")
    def test_jit_cuda_archflags(self):
        # Test a number of combinations:
        #   - the default for the machine we're testing on
        #   - separators, which can be ';' (most common) or ' '
        #   - architecture names
        #   - with/without '+PTX'

        n = torch.cuda.device_count()
        capabilities = {torch.cuda.get_device_capability(i) for i in range(n)}
        # expected values is a length-2 tuple: (list of ELF, list of PTX)
        # note: there should not be more than one PTX value
        archflags = {
            "": (
                [f"{capability[0]}{capability[1]}" for capability in capabilities],
                None,
            ),
            "Maxwell+Tegra;6.1": (["53", "61"], None),
            "Volta": (["70"], ["70"]),
        }
        archflags["7.5+PTX"] = (["75"], ["75"])
        archflags["5.0;6.0+PTX;7.0;7.5"] = (["50", "60", "70", "75"], ["60"])
        if int(torch.version.cuda.split(".")[0]) < 12:
            # CUDA 12 drops compute capability < 5.0
            archflags["Pascal 3.5"] = (["35", "60", "61"], None)

        for flags, expected in archflags.items():
            try:
                self._run_jit_cuda_archflags(flags, expected)
            except RuntimeError as e:
                # Using the device default (empty flags) may fail if the device is newer than the CUDA compiler
                # This raises a RuntimeError with a specific message which we explicitly ignore here
                if not flags and "Error building" in str(e):
                    pass
                else:
                    raise
            try:
                torch.cuda.synchronize()
            except RuntimeError:
                # Ignore any error, e.g. unsupported PTX code on the current device,
                # to avoid errors from here leaking into other tests
                pass

    @unittest.skipIf(not TEST_CUDNN, "CuDNN not found")
    @unittest.skipIf(TEST_ROCM, "Not supported on ROCm")
    def test_jit_cudnn_extension(self):
        # implementation of CuDNN ReLU
        if IS_WINDOWS:
            extra_ldflags = ["cudnn.lib"]
        else:
            extra_ldflags = ["-lcudnn"]
        module = torch.utils.cpp_extension.load(
            name="torch_test_cudnn_extension",
            sources=["cpp_extensions/cudnn_extension.cpp"],
            extra_ldflags=extra_ldflags,
            verbose=True,
            with_cuda=True,
        )

        x = torch.randn(100, device="cuda", dtype=torch.float32)
        y = torch.zeros(100, device="cuda", dtype=torch.float32)
        module.cudnn_relu(x, y)  # y=relu(x)
        self.assertEqual(torch.nn.functional.relu(x), y)
        with self.assertRaisesRegex(RuntimeError, "same size"):
            y_incorrect = torch.zeros(20, device="cuda", dtype=torch.float32)
            module.cudnn_relu(x, y_incorrect)

    def test_inline_jit_compile_extension_with_functions_as_list(self):
        cpp_source = """
        torch::Tensor tanh_add(torch::Tensor x, torch::Tensor y) {
          return x.tanh() + y.tanh();
        }
        """

        module = torch.utils.cpp_extension.load_inline(
            name="inline_jit_extension_with_functions_list",
            cpp_sources=cpp_source,
            functions="tanh_add",
            verbose=True,
        )

        self.assertEqual(module.tanh_add.__doc__.split("\n")[2], "tanh_add")

        x = torch.randn(4, 4)
        y = torch.randn(4, 4)

        z = module.tanh_add(x, y)
        self.assertEqual(z, x.tanh() + y.tanh())

    def test_inline_jit_compile_extension_with_functions_as_dict(self):
        cpp_source = """
        torch::Tensor tanh_add(torch::Tensor x, torch::Tensor y) {
          return x.tanh() + y.tanh();
        }
        """

        module = torch.utils.cpp_extension.load_inline(
            name="inline_jit_extension_with_functions_dict",
            cpp_sources=cpp_source,
            functions={"tanh_add": "Tanh and then sum :D"},
            verbose=True,
        )

        self.assertEqual(module.tanh_add.__doc__.split("\n")[2], "Tanh and then sum :D")

    def test_inline_jit_compile_extension_multiple_sources_and_no_functions(self):
        cpp_source1 = """
        torch::Tensor sin_add(torch::Tensor x, torch::Tensor y) {
          return x.sin() + y.sin();
        }
        """

        cpp_source2 = """
        #include <torch/extension.h>
        torch::Tensor sin_add(torch::Tensor x, torch::Tensor y);
        PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
          m.def("sin_add", &sin_add, "sin(x) + sin(y)");
        }
        """

        module = torch.utils.cpp_extension.load_inline(
            name="inline_jit_extension",
            cpp_sources=[cpp_source1, cpp_source2],
            verbose=True,
        )

        x = torch.randn(4, 4)
        y = torch.randn(4, 4)

        z = module.sin_add(x, y)
        self.assertEqual(z, x.sin() + y.sin())

    @unittest.skip("Temporarily disabled")
    @unittest.skipIf(not (TEST_CUDA or TEST_ROCM), "CUDA not found")
    def test_inline_jit_compile_extension_cuda(self):
        cuda_source = """
        __global__ void cos_add_kernel(
            const float* __restrict__ x,
            const float* __restrict__ y,
            float* __restrict__ output,
            const int size) {
          const auto index = blockIdx.x * blockDim.x + threadIdx.x;
          if (index < size) {
            output[index] = __cosf(x[index]) + __cosf(y[index]);
          }
        }

        torch::Tensor cos_add(torch::Tensor x, torch::Tensor y) {
          auto output = torch::zeros_like(x);
          const int threads = 1024;
          const int blocks = (output.numel() + threads - 1) / threads;
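          // launch one thread per element; rounding blocks up covers the whole tensor
          // (data<T>() is the legacy accessor; newer code uses data_ptr<T>())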
          cos_add_kernel<<<blocks, threads>>>(x.data<float>(), y.data<float>(), output.data<float>(), output.numel());
          return output;
        }
        """

        # Here, the C++ source need only declare the function signature.
        cpp_source = "torch::Tensor cos_add(torch::Tensor x, torch::Tensor y);"

        module = torch.utils.cpp_extension.load_inline(
            name="inline_jit_extension_cuda",
            cpp_sources=cpp_source,
            cuda_sources=cuda_source,
            functions=["cos_add"],
            verbose=True,
        )

        self.assertEqual(module.cos_add.__doc__.split("\n")[2], "cos_add")

        x = torch.randn(4, 4, device="cuda", dtype=torch.float32)
        y = torch.randn(4, 4, device="cuda", dtype=torch.float32)

        z = module.cos_add(x, y)
        self.assertEqual(z, x.cos() + y.cos())

    @unittest.skip("Temporarily disabled")
    @unittest.skipIf(not (TEST_CUDA or TEST_ROCM), "CUDA not found")
    def test_inline_jit_compile_custom_op_cuda(self):
        cuda_source = """
        __global__ void cos_add_kernel(
            const float* __restrict__ x,
            const float* __restrict__ y,
            float* __restrict__ output,
            const int size) {
          const auto index = blockIdx.x * blockDim.x + threadIdx.x;
          if (index < size) {
            output[index] = __cosf(x[index]) + __cosf(y[index]);
          }
        }

        torch::Tensor cos_add(torch::Tensor x, torch::Tensor y) {
          auto output = torch::zeros_like(x);
          const int threads = 1024;
          const int blocks = (output.numel() + threads - 1) / threads;
          cos_add_kernel<<<blocks, threads>>>(x.data_ptr<float>(), y.data_ptr<float>(), output.data_ptr<float>(), output.numel());
          return output;
        }
        """

        # Here, the C++ source need only declare the function signature.
        cpp_source = """
        #include <torch/library.h>
        torch::Tensor cos_add(torch::Tensor x, torch::Tensor y);

        TORCH_LIBRARY(inline_jit_extension_custom_op_cuda, m) {
          m.def("cos_add", cos_add);
        }
        """

        torch.utils.cpp_extension.load_inline(
            name="inline_jit_extension_custom_op_cuda",
            cpp_sources=cpp_source,
            cuda_sources=cuda_source,
            verbose=True,
            is_python_module=False,
        )

        x = torch.randn(4, 4, device="cuda", dtype=torch.float32)
        y = torch.randn(4, 4, device="cuda", dtype=torch.float32)

        z = torch.ops.inline_jit_extension_custom_op_cuda.cos_add(x, y)
        self.assertEqual(z, x.cos() + y.cos())

    def test_inline_jit_compile_extension_throws_when_functions_is_bad(self):
        with self.assertRaises(ValueError):
            torch.utils.cpp_extension.load_inline(
                name="invalid_jit_extension", cpp_sources="", functions=5
            )

    def test_lenient_flag_handling_in_jit_extensions(self):
        cpp_source = """
        torch::Tensor tanh_add(torch::Tensor x, torch::Tensor y) {
          return x.tanh() + y.tanh();
        }
        """

        module = torch.utils.cpp_extension.load_inline(
            name="lenient_flag_handling_extension",
            cpp_sources=cpp_source,
            functions="tanh_add",
            extra_cflags=["-g\n\n", "-O0 -Wall"],
            extra_include_paths=[" cpp_extensions\n"],
            verbose=True,
        )

        x = torch.zeros(100, dtype=torch.float32)
        y = torch.zeros(100, dtype=torch.float32)
        z = module.tanh_add(x, y).cpu()
        self.assertEqual(z, x.tanh() + y.tanh())

    @unittest.skip("Temporarily disabled")
    @unittest.skipIf(not (TEST_CUDA or TEST_ROCM), "CUDA not found")
    def test_half_support(self):
        """
        Checks for an issue with operator< ambiguity for half when certain
        THC headers are included.

        See https://github.com/pytorch/pytorch/pull/10301#issuecomment-416773333
        for the corresponding issue.
        """
        cuda_source = """
        template<typename T, typename U>
        __global__ void half_test_kernel(const T* input, U* output) {
          if (input[0] < input[1] || input[0] >= input[1]) {
            output[0] = 123;
          }
        }

        torch::Tensor half_test(torch::Tensor input) {
          auto output = torch::empty(1, input.options().dtype(torch::kFloat));
          AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "half_test", [&] {
            half_test_kernel<scalar_t><<<1, 1>>>(
                input.data<scalar_t>(),
                output.data<float>());
          });
          return output;
        }
        """

        module = torch.utils.cpp_extension.load_inline(
            name="half_test_extension",
            cpp_sources="torch::Tensor half_test(torch::Tensor input);",
            cuda_sources=cuda_source,
            functions=["half_test"],
            verbose=True,
        )

        x = torch.randn(3, device="cuda", dtype=torch.half)
        result = module.half_test(x)
        self.assertEqual(result[0], 123)

    def test_reload_jit_extension(self):
        def compile(code):
            return torch.utils.cpp_extension.load_inline(
                name="reloaded_jit_extension",
                cpp_sources=code,
                functions="f",
                verbose=True,
            )

        module = compile("int f() { return 123; }")
        self.assertEqual(module.f(), 123)

        module = compile("int f() { return 456; }")
        self.assertEqual(module.f(), 456)
        module = compile("int f() { return 456; }")
        self.assertEqual(module.f(), 456)

        module = compile("int f() { return 789; }")
        self.assertEqual(module.f(), 789)

    def test_cpp_frontend_module_has_same_output_as_python(self, dtype=torch.double):
        extension = torch.utils.cpp_extension.load(
            name="cpp_frontend_extension",
            sources="cpp_extensions/cpp_frontend_extension.cpp",
            verbose=True,
        )

        input = torch.randn(2, 5, dtype=dtype)
        cpp_linear = extension.Net(5, 2)
        cpp_linear.to(dtype)
        python_linear = torch.nn.Linear(5, 2).to(dtype)

        # First make sure they have the same parameters
        cpp_parameters = dict(cpp_linear.named_parameters())
        with torch.no_grad():
            python_linear.weight.copy_(cpp_parameters["fc.weight"])
            python_linear.bias.copy_(cpp_parameters["fc.bias"])

        cpp_output = cpp_linear.forward(input)
        python_output = python_linear(input)
        self.assertEqual(cpp_output, python_output)

        cpp_output.sum().backward()
        python_output.sum().backward()

        for p in cpp_linear.parameters():
            self.assertFalse(p.grad is None)

        self.assertEqual(cpp_parameters["fc.weight"].grad, python_linear.weight.grad)
        self.assertEqual(cpp_parameters["fc.bias"].grad, python_linear.bias.grad)

    def test_cpp_frontend_module_python_inter_op(self):
        extension = torch.utils.cpp_extension.load(
            name="cpp_frontend_extension",
            sources="cpp_extensions/cpp_frontend_extension.cpp",
            verbose=True,
        )

        # Create a torch.nn.Module which uses the C++ module as a submodule.
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.nn.Parameter(torch.tensor(1.0))
                self.net = extension.Net(3, 5)

            def forward(self, input):
                return self.net.forward(input) + self.x

        net = extension.Net(5, 2)
        net.double()
        net.to(torch.get_default_dtype())
        self.assertEqual(str(net), "Net")

        # Further embed the torch.nn.Module into a Sequential, and also add the
        # C++ module as an element of the Sequential.
        sequential = torch.nn.Sequential(M(), torch.nn.Tanh(), net, torch.nn.Sigmoid())

        input = torch.randn(2, 3)
        # Try calling the module!
        output = sequential.forward(input)
        # The call operator is bound to forward too.
        self.assertEqual(output, sequential(input))
        self.assertEqual(list(output.shape), [2, 2])

        # Make changes to the module hierarchy.
        old_dtype = torch.get_default_dtype()
        sequential.to(torch.float64)
        sequential.to(torch.float32)
        sequential.to(old_dtype)
        self.assertEqual(sequential[2].parameters()[0].dtype, old_dtype)

        # Make sure we can access these methods recursively.
        self.assertEqual(
            len(list(sequential.parameters())), len(net.parameters()) * 2 + 1
        )
        self.assertEqual(
            len(list(sequential.named_parameters())),
            len(net.named_parameters()) * 2 + 1,
        )
        self.assertEqual(len(list(sequential.buffers())), len(net.buffers()) * 2)
        self.assertEqual(len(list(sequential.modules())), 8)

        # Test clone()
        net2 = net.clone()
        self.assertEqual(len(net.parameters()), len(net2.parameters()))
        self.assertEqual(len(net.buffers()), len(net2.buffers()))
        self.assertEqual(len(net.modules()), len(net2.modules()))

        # Try differentiating through the whole module.
        for parameter in net.parameters():
            self.assertIsNone(parameter.grad)
        output.sum().backward()
        for parameter in net.parameters():
            self.assertFalse(parameter.grad is None)
            self.assertGreater(parameter.grad.sum(), 0)

        # Try calling zero_grad()
        net.zero_grad()
        for p in net.parameters():
            assert p.grad is None, "zero_grad defaults to setting grads to None"

        # Test train(), eval(), training (a property)
        self.assertTrue(net.training)
        net.eval()
        self.assertFalse(net.training)
        net.train()
        self.assertTrue(net.training)
        net.eval()

        # Try calling the additional methods we registered.
        biased_input = torch.randn(4, 5)
        output_before = net.forward(biased_input)
        bias = net.get_bias().clone()
        self.assertEqual(list(bias.shape), [2])
        net.set_bias(bias + 1)
        self.assertEqual(net.get_bias(), bias + 1)
        output_after = net.forward(biased_input)

        self.assertNotEqual(output_before, output_after)

        # Try accessing parameters
        self.assertEqual(len(net.parameters()), 2)
        np = net.named_parameters()
        self.assertEqual(len(np), 2)
        self.assertIn("fc.weight", np)
        self.assertIn("fc.bias", np)

        self.assertEqual(len(net.buffers()), 1)
        nb = net.named_buffers()
        self.assertEqual(len(nb), 1)
        self.assertIn("buf", nb)
        self.assertEqual(nb[0][1], torch.eye(5))

    def test_cpp_frontend_module_has_up_to_date_attributes(self):
        extension = torch.utils.cpp_extension.load(
            name="cpp_frontend_extension",
            sources="cpp_extensions/cpp_frontend_extension.cpp",
            verbose=True,
        )

        net = extension.Net(5, 2)

        self.assertEqual(len(net._parameters), 0)
        net.add_new_parameter("foo", torch.eye(5))
        self.assertEqual(len(net._parameters), 1)

        self.assertEqual(len(net._buffers), 1)
        net.add_new_buffer("bar", torch.eye(5))
        self.assertEqual(len(net._buffers), 2)

        self.assertEqual(len(net._modules), 1)
        net.add_new_submodule("fc2")
        self.assertEqual(len(net._modules), 2)

    @unittest.skipIf(not (TEST_CUDA or TEST_ROCM), "CUDA not found")
    def test_cpp_frontend_module_python_inter_op_with_cuda(self):
        extension = torch.utils.cpp_extension.load(
            name="cpp_frontend_extension",
            sources="cpp_extensions/cpp_frontend_extension.cpp",
            verbose=True,
        )

        net = extension.Net(5, 2)
        for p in net.parameters():
            self.assertTrue(p.device.type == "cpu")
        cpu_parameters = [p.clone() for p in net.parameters()]

        device = torch.device("cuda", 0)
        net.to(device)

        for i, p in enumerate(net.parameters()):
            self.assertTrue(p.device.type == "cuda")
            self.assertTrue(p.device.index == 0)
            self.assertEqual(cpu_parameters[i], p)

        net.cpu()
        net.add_new_parameter("a", torch.eye(5))
        net.add_new_parameter("b", torch.eye(5))
        net.add_new_buffer("c", torch.eye(5))
        net.add_new_buffer("d", torch.eye(5))
        net.add_new_submodule("fc2")
        net.add_new_submodule("fc3")

        for p in net.parameters():
            self.assertTrue(p.device.type == "cpu")

        net.cuda()

        for p in net.parameters():
            self.assertTrue(p.device.type == "cuda")

    def test_returns_shared_library_path_when_is_python_module_is_true(self):
        source = """
        #include <torch/script.h>
        torch::Tensor func(torch::Tensor x) { return x; }
        static torch::RegisterOperators r("test::func", &func);
        """
        torch.utils.cpp_extension.load_inline(
            name="is_python_module",
            cpp_sources=source,
            functions="func",
            verbose=True,
            is_python_module=False,
        )
        self.assertEqual(torch.ops.test.func(torch.eye(5)), torch.eye(5))

    def test_set_default_type_also_changes_aten_default_type(self):
        module = torch.utils.cpp_extension.load_inline(
            name="test_set_default_type",
            cpp_sources="torch::Tensor get() { return torch::empty({}); }",
            functions="get",
            verbose=True,
        )

        initial_default = torch.get_default_dtype()
        try:
            self.assertEqual(module.get().dtype, initial_default)
            torch.set_default_dtype(torch.float64)
            self.assertEqual(module.get().dtype, torch.float64)
            torch.set_default_dtype(torch.float32)
            self.assertEqual(module.get().dtype, torch.float32)
            torch.set_default_dtype(torch.float16)
            self.assertEqual(module.get().dtype, torch.float16)
        finally:
            torch.set_default_dtype(initial_default)

    def test_compilation_error_formatting(self):
        # Test that the missing-semicolon error message has linebreaks in it.
        # This'll fail if the message has been munged into a single line.
        # It's hard to write anything more specific as every compiler has its own
        # error formatting.
        with self.assertRaises(RuntimeError) as e:
            torch.utils.cpp_extension.load_inline(
                name="test_compilation_error_formatting",
                cpp_sources="int main() { return 0 }",
            )
        pattern = r".*(\\n|\\r).*"
        self.assertNotRegex(str(e), pattern)

    def test_warning(self):
        # Note: the module created from this source will include the py::key_error
        # symbol. But because of visibility and the fact that it lives in a
        # different compilation unit than pybind, this trips up ubsan even though
        # it is fine. "ubsan.supp" thus needs to contain "vptr:warn_mod.so".
        source = """
        // error_type:
        //   0: no error
        //   1: torch::TypeError
        //   2: python_error()
        //   3: py::error_already_set
        at::Tensor foo(at::Tensor x, int error_type) {
            std::ostringstream err_stream;
            err_stream << "Error with " << x.type();

            TORCH_WARN(err_stream.str());
            if(error_type == 1) {
                throw torch::TypeError(err_stream.str().c_str());
            }
            if(error_type == 2) {
                PyObject* obj = PyTuple_New(-1);
                TORCH_CHECK(!obj);
                // Pretend it was caught in a different thread and restored here
                auto e = python_error();
                e.persist();
                e.restore();
                throw e;
            }
            if(error_type == 3) {
                throw py::key_error(err_stream.str());
            }
            return x.cos();
        }
        """

        # Ensure double type for hard-coded c name below
        t = torch.rand(2).double()
        cpp_tensor_name = r"CPUDoubleType"

        # Without error handling, the warnings cannot be caught
        warn_mod = torch.utils.cpp_extension.load_inline(
            name="warn_mod",
            cpp_sources=[source],
            functions=["foo"],
            with_pytorch_error_handling=False,
        )

        with warnings.catch_warnings(record=True) as w:
            warn_mod.foo(t, 0)
            self.assertEqual(len(w), 0)

            with self.assertRaisesRegex(TypeError, t.type()):
                warn_mod.foo(t, 1)
            self.assertEqual(len(w), 0)

            with self.assertRaisesRegex(
                SystemError, "bad argument to internal function"
            ):
                warn_mod.foo(t, 2)
            self.assertEqual(len(w), 0)

            with self.assertRaisesRegex(KeyError, cpp_tensor_name):
                warn_mod.foo(t, 3)
            self.assertEqual(len(w), 0)
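        # Reload the same sources with with_pytorch_error_handling=True: warnings raised
        # via TORCH_WARN are now forwarded to Python and recorded by catch_warnings below.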
        warn_mod = torch.utils.cpp_extension.load_inline(
            name="warn_mod",
            cpp_sources=[source],
            functions=["foo"],
            with_pytorch_error_handling=True,
        )

        with warnings.catch_warnings(record=True) as w:
            # Caught with no error should be detected
            warn_mod.foo(t, 0)
            self.assertEqual(len(w), 1)

            # Caught with cpp error should also be detected
            with self.assertRaisesRegex(TypeError, t.type()):
                warn_mod.foo(t, 1)
            self.assertEqual(len(w), 2)

            # Caught with python error should also be detected
            with self.assertRaisesRegex(
                SystemError, "bad argument to internal function"
            ):
                warn_mod.foo(t, 2)
            self.assertEqual(len(w), 3)

            # Caught with pybind error should also be detected
            # Note that there is no type name translation for pybind errors
            with self.assertRaisesRegex(KeyError, cpp_tensor_name):
                warn_mod.foo(t, 3)
            self.assertEqual(len(w), 4)

        # Make sure warnings that are turned into errors are handled properly
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("error")

            # No error happened, so the warning should be raised
            with self.assertRaisesRegex(UserWarning, t.type()):
                warn_mod.foo(t, 0)
            self.assertEqual(len(w), 0)

            # Another error happened, so the warning is ignored
            with self.assertRaisesRegex(TypeError, t.type()):
                warn_mod.foo(t, 1)
            self.assertEqual(len(w), 0)

    def test_autograd_from_cpp(self):
        source = """
        void run_back(at::Tensor x) {
            x.backward({});
        }

        void run_back_no_gil(at::Tensor x) {
            pybind11::gil_scoped_release no_gil;
            x.backward({});
        }
        """

        class MyFn(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x.clone()

            @staticmethod
            def backward(ctx, gx):
                return gx

        test_backward_deadlock = torch.utils.cpp_extension.load_inline(
            name="test_backward_deadlock",
            cpp_sources=[source],
            functions=["run_back", "run_back_no_gil"],
        )

        # This used to deadlock
        inp = torch.rand(20, requires_grad=True)
        loss = MyFn.apply(inp).sum()
        with self.assertRaisesRegex(
            RuntimeError, "The autograd engine was called while holding the GIL."
        ):
            test_backward_deadlock.run_back(loss)

        inp = torch.rand(20, requires_grad=True)
        loss = MyFn.apply(inp).sum()
        test_backward_deadlock.run_back_no_gil(loss)

    def test_custom_compound_op_autograd(self):
        # Test that a custom compound op (i.e. a custom op that just calls other aten ops)
        # correctly returns gradients of those other ops

        source = """
        #include <torch/library.h>
        torch::Tensor my_add(torch::Tensor x, torch::Tensor y) {
            return x + y;
        }
        TORCH_LIBRARY(my, m) {
            m.def("add", &my_add);
        }
        """

        torch.utils.cpp_extension.load_inline(
            name="is_python_module",
            cpp_sources=source,
            verbose=True,
            is_python_module=False,
        )

        a = torch.randn(5, 5, requires_grad=True)
        b = torch.randn(5, 5, requires_grad=True)

        for fast_mode in (True, False):
            gradcheck(torch.ops.my.add, [a, b], eps=1e-2, fast_mode=fast_mode)

    def test_custom_functorch_error(self):
        # Test that a custom C++ Function raises an error under functorch transforms
        identity_m = torch.utils.cpp_extension.load(
            name="identity",
            sources=["cpp_extensions/identity.cpp"],
        )

        t = torch.randn(3, requires_grad=True)

        msg = r"cannot use C\+\+ torch::autograd::Function with functorch"
        with self.assertRaisesRegex(RuntimeError, msg):
            torch.func.vmap(identity_m.identity)(t)

        with self.assertRaisesRegex(RuntimeError, msg):
            torch.func.grad(identity_m.identity)(t)

    def test_gen_extension_h_pch(self):
        if not IS_LINUX:
            return

        source = """
        at::Tensor sin_add(at::Tensor x, at::Tensor y) {
          return x.sin() + y.sin();
        }
        """

        head_file_pch = os.path.join(_TORCH_PATH, "include", "torch", "extension.h.gch")
        head_file_signature = os.path.join(
            _TORCH_PATH, "include", "torch", "extension.h.sign"
        )

        remove_extension_h_precompiler_headers()
        pch_exist = os.path.exists(head_file_pch)
        signature_exist = os.path.exists(head_file_signature)
        self.assertEqual(pch_exist, False)
        self.assertEqual(signature_exist, False)
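        # Rebuilding with use_pch=True should regenerate the precompiled header and its
        # signature file; this is only asserted below when the compiler is GCC.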
        torch.utils.cpp_extension.load_inline(
            name="inline_extension_with_pch",
            cpp_sources=[source],
            functions=["sin_add"],
            verbose=True,
            use_pch=True,
        )
        pch_exist = os.path.exists(head_file_pch)
        signature_exist = os.path.exists(head_file_signature)

        compiler = get_cxx_compiler()
        if check_compiler_is_gcc(compiler):
            self.assertEqual(pch_exist, True)
            self.assertEqual(signature_exist, True)


if __name__ == "__main__":
    common.run_tests()