# Owner(s): ["module: cpp-extensions"]

import glob
import os
import re
import shutil
import subprocess
import sys
import tempfile
import unittest
import warnings

import torch
import torch.backends.cudnn
import torch.multiprocessing as mp
import torch.testing._internal.common_utils as common
import torch.utils.cpp_extension
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_CUDNN
from torch.testing._internal.common_utils import gradcheck
from torch.utils.cpp_extension import (
    _TORCH_PATH,
    check_compiler_is_gcc,
    CUDA_HOME,
    get_cxx_compiler,
    remove_extension_h_precompiler_headers,
    ROCM_HOME,
)


# define TEST_ROCM before changing TEST_CUDA
TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None
TEST_CUDA = TEST_CUDA and CUDA_HOME is not None
TEST_MPS = torch.backends.mps.is_available()
IS_WINDOWS = sys.platform == "win32"
IS_LINUX = sys.platform.startswith("linux")


def remove_build_path():
    default_build_root = torch.utils.cpp_extension.get_default_build_root()
    if os.path.exists(default_build_root):
        if IS_WINDOWS:
            # rmtree raises a permission error on Windows: [WinError 5] Access is
            # denied. Shelling out to `rm -rf` is a workaround.
            subprocess.run(["rm", "-rf", default_build_root], stdout=subprocess.PIPE)
        else:
            shutil.rmtree(default_build_root)


# There's only one test that runs gradcheck; run slow mode manually.
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCppExtensionJIT(common.TestCase):
    """Tests just-in-time cpp extensions.
    Don't confuse this with the PyTorch JIT (aka TorchScript).
    """
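    # Common pattern in these tests: torch.utils.cpp_extension.load() /
    # load_inline() compile the given C++/CUDA sources at test time (into the
    # default build root that setUpClass/tearDownClass clear via
    # remove_build_path) and return the resulting module, whose functions are
    # then exercised from Python.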

    def setUp(self):
        super().setUp()
        # cpp extensions use relative paths. Those paths are relative to
        # this file, so we'll change the working directory temporarily
        self.old_working_dir = os.getcwd()
        os.chdir(os.path.dirname(os.path.abspath(__file__)))

    def tearDown(self):
        super().tearDown()
        # restore the working directory (see setUp)
        os.chdir(self.old_working_dir)

    @classmethod
    def setUpClass(cls):
        remove_build_path()

    @classmethod
    def tearDownClass(cls):
        remove_build_path()

    def test_jit_compile_extension(self):
        module = torch.utils.cpp_extension.load(
            name="jit_extension",
            sources=[
                "cpp_extensions/jit_extension.cpp",
                "cpp_extensions/jit_extension2.cpp",
            ],
            extra_include_paths=[
                "cpp_extensions",
                "path / with spaces in it",
                "path with quote'",
            ],
            extra_cflags=["-g"],
            verbose=True,
        )
        x = torch.randn(4, 4)
        y = torch.randn(4, 4)

        z = module.tanh_add(x, y)
        self.assertEqual(z, x.tanh() + y.tanh())

        # Check that we can call a function that is not defined in the main C++ file.
        z = module.exp_add(x, y)
        self.assertEqual(z, x.exp() + y.exp())

        # Check that we can use the JIT-compiled class.
        doubler = module.Doubler(2, 2)
        self.assertIsNone(doubler.get().grad)
        self.assertEqual(doubler.get().sum(), 4)
        self.assertEqual(doubler.forward().sum(), 8)

    @unittest.skipIf(not (TEST_CUDA or TEST_ROCM), "CUDA not found")
    def test_jit_cuda_extension(self):
        # NOTE: The name of the extension must equal the name of the module.
        module = torch.utils.cpp_extension.load(
            name="torch_test_cuda_extension",
            sources=[
                "cpp_extensions/cuda_extension.cpp",
                "cpp_extensions/cuda_extension.cu",
            ],
            extra_cuda_cflags=["-O2"],
            verbose=True,
            keep_intermediates=False,
        )

        x = torch.zeros(100, device="cuda", dtype=torch.float32)
        y = torch.zeros(100, device="cuda", dtype=torch.float32)

        z = module.sigmoid_add(x, y).cpu()

        # 2 * sigmoid(0) = 2 * 0.5 = 1
        self.assertEqual(z, torch.ones_like(z))

    @unittest.skipIf(not TEST_MPS, "MPS not found")
    def test_mps_extension(self):
        module = torch.utils.cpp_extension.load(
            name="torch_test_mps_extension",
            sources=[
                "cpp_extensions/mps_extension.mm",
            ],
            verbose=True,
            keep_intermediates=False,
        )

        tensor_length = 100000
        x = torch.randn(tensor_length, device="cpu", dtype=torch.float32)
        y = torch.randn(tensor_length, device="cpu", dtype=torch.float32)

        cpu_output = module.get_cpu_add_output(x, y)
        mps_output = module.get_mps_add_output(x.to("mps"), y.to("mps"))

        self.assertEqual(cpu_output, mps_output.to("cpu"))

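    # Helper for test_jit_cuda_archflags below: builds the CUDA extension with
    # TORCH_CUDA_ARCH_LIST set to `flags` and then runs cuobjdump over the
    # resulting shared library to verify which SASS (ELF) and PTX architectures
    # were actually embedded.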
    def _run_jit_cuda_archflags(self, flags, expected):
        # Compile an extension with given `flags`
        def _check_cuobjdump_output(expected_values, is_ptx=False):
            elf_or_ptx = "--list-ptx" if is_ptx else "--list-elf"
            lib_ext = ".pyd" if IS_WINDOWS else ".so"
            # Note: the extension file name may include _v1, _v2, so first find the exact name
            ext_filename = glob.glob(
                os.path.join(temp_dir, "cudaext_archflag*" + lib_ext)
            )[0]
            command = ["cuobjdump", elf_or_ptx, ext_filename]
            p = subprocess.Popen(
                command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
            output, err = p.communicate()
            output = output.decode("ascii")
            err = err.decode("ascii")

            if p.returncode != 0 or err != "":
                raise AssertionError(
                    f"Flags: {flags}\nReturncode: {p.returncode}\nStderr: {err}\n"
                    f"Output: {output} "
                )

            actual_arches = sorted(re.findall(r"sm_\d\d", output))
            expected_arches = sorted(["sm_" + xx for xx in expected_values])
            self.assertEqual(
                actual_arches,
                expected_arches,
                msg=f"Flags: {flags},  Actual: {actual_arches},  Expected: {expected_arches}\n"
                f"Stderr: {err}\nOutput: {output}",
            )

        temp_dir = tempfile.mkdtemp()
        old_envvar = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
        try:
            os.environ["TORCH_CUDA_ARCH_LIST"] = flags

            params = {
                "name": "cudaext_archflags",
                "sources": [
                    "cpp_extensions/cuda_extension.cpp",
                    "cpp_extensions/cuda_extension.cu",
                ],
                "extra_cuda_cflags": ["-O2"],
                "verbose": True,
                "build_directory": temp_dir,
            }

            if IS_WINDOWS:
                p = mp.Process(target=torch.utils.cpp_extension.load, kwargs=params)

                # Compile and load the test CUDA arch in a different Python process to avoid
                # polluting the current one and causing test_jit_cuda_extension to fail on
                # Windows. There is no clear way to unload a module after it has been imported
                # and torch.utils.cpp_extension.load builds and loads the module in one go.
                # See https://github.com/pytorch/pytorch/issues/61655 for more details
                p.start()
                p.join()
            else:
                torch.utils.cpp_extension.load(**params)

            # Expected output for --list-elf:
            #   ELF file    1: cudaext_archflags.1.sm_61.cubin
            #   ELF file    2: cudaext_archflags.2.sm_52.cubin
            _check_cuobjdump_output(expected[0])
            if expected[1] is not None:
                # Expected output for --list-ptx:
                #   PTX file    1: cudaext_archflags.1.sm_61.ptx
                _check_cuobjdump_output(expected[1], is_ptx=True)
        finally:
            if IS_WINDOWS:
                # rmtree raises a permission error on Windows: [WinError 5] Access is
                # denied. Shelling out to `rm -rf` is a workaround.
                subprocess.run(["rm", "-rf", temp_dir], stdout=subprocess.PIPE)
            else:
                shutil.rmtree(temp_dir)

            if old_envvar is None:
                os.environ.pop("TORCH_CUDA_ARCH_LIST")
            else:
                os.environ["TORCH_CUDA_ARCH_LIST"] = old_envvar

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    @unittest.skipIf(TEST_ROCM, "disabled on rocm")
    def test_jit_cuda_archflags(self):
        # Test a number of combinations:
        #   - the default for the machine we're testing on
        #   - Separators, can be ';' (most common) or ' '
        #   - Architecture names
        #   - With/without '+PTX'

        n = torch.cuda.device_count()
        capabilities = {torch.cuda.get_device_capability(i) for i in range(n)}
        # expected value is a length-2 tuple: (list of ELF archs, list of PTX archs)
        # note: there should not be more than one PTX value
        archflags = {
            "": (
                [f"{capability[0]}{capability[1]}" for capability in capabilities],
                None,
            ),
            "Maxwell+Tegra;6.1": (["53", "61"], None),
            "Volta": (["70"], ["70"]),
        }
        archflags["7.5+PTX"] = (["75"], ["75"])
        archflags["5.0;6.0+PTX;7.0;7.5"] = (["50", "60", "70", "75"], ["60"])
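        # In "5.0;6.0+PTX;7.0;7.5" only the entry marked '+PTX' is expected to
        # carry PTX code in the binary, hence the PTX list is just ["60"].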
        if int(torch.version.cuda.split(".")[0]) < 12:
            # CUDA 12 drops compute capability < 5.0
            archflags["Pascal 3.5"] = (["35", "60", "61"], None)

        for flags, expected in archflags.items():
            try:
                self._run_jit_cuda_archflags(flags, expected)
            except RuntimeError as e:
                # Using the device default (empty flags) may fail if the device is newer than the CUDA compiler
                # This raises a RuntimeError with a specific message which we explicitly ignore here
                if not flags and "Error building" in str(e):
                    pass
                else:
                    raise
            try:
                torch.cuda.synchronize()
            except RuntimeError:
                # Ignore any error, e.g. unsupported PTX code on current device
                # to avoid errors from here leaking into other tests
                pass

    @unittest.skipIf(not TEST_CUDNN, "CuDNN not found")
    @unittest.skipIf(TEST_ROCM, "Not supported on ROCm")
    def test_jit_cudnn_extension(self):
        # implementation of CuDNN ReLU
        if IS_WINDOWS:
            extra_ldflags = ["cudnn.lib"]
        else:
            extra_ldflags = ["-lcudnn"]
        module = torch.utils.cpp_extension.load(
            name="torch_test_cudnn_extension",
            sources=["cpp_extensions/cudnn_extension.cpp"],
            extra_ldflags=extra_ldflags,
            verbose=True,
            with_cuda=True,
        )

        x = torch.randn(100, device="cuda", dtype=torch.float32)
        y = torch.zeros(100, device="cuda", dtype=torch.float32)
        module.cudnn_relu(x, y)  # y=relu(x)
        self.assertEqual(torch.nn.functional.relu(x), y)
        with self.assertRaisesRegex(RuntimeError, "same size"):
            y_incorrect = torch.zeros(20, device="cuda", dtype=torch.float32)
            module.cudnn_relu(x, y_incorrect)

    def test_inline_jit_compile_extension_with_functions_as_list(self):
        cpp_source = """
        torch::Tensor tanh_add(torch::Tensor x, torch::Tensor y) {
          return x.tanh() + y.tanh();
        }
        """

        module = torch.utils.cpp_extension.load_inline(
            name="inline_jit_extension_with_functions_list",
            cpp_sources=cpp_source,
            functions="tanh_add",
            verbose=True,
        )

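        # `functions=` makes load_inline auto-generate the pybind11 binding; the
        # generated docstring's third line carries the function name (or the
        # custom docstring when a dict is passed), which is what the check below
        # reads.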
        self.assertEqual(module.tanh_add.__doc__.split("\n")[2], "tanh_add")

        x = torch.randn(4, 4)
        y = torch.randn(4, 4)

        z = module.tanh_add(x, y)
        self.assertEqual(z, x.tanh() + y.tanh())

    def test_inline_jit_compile_extension_with_functions_as_dict(self):
        cpp_source = """
        torch::Tensor tanh_add(torch::Tensor x, torch::Tensor y) {
          return x.tanh() + y.tanh();
        }
        """

        module = torch.utils.cpp_extension.load_inline(
            name="inline_jit_extension_with_functions_dict",
            cpp_sources=cpp_source,
            functions={"tanh_add": "Tanh and then sum :D"},
            verbose=True,
        )

        self.assertEqual(module.tanh_add.__doc__.split("\n")[2], "Tanh and then sum :D")

    def test_inline_jit_compile_extension_multiple_sources_and_no_functions(self):
        cpp_source1 = """
        torch::Tensor sin_add(torch::Tensor x, torch::Tensor y) {
          return x.sin() + y.sin();
        }
        """

        cpp_source2 = """
        #include <torch/extension.h>
        torch::Tensor sin_add(torch::Tensor x, torch::Tensor y);
        PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
          m.def("sin_add", &sin_add, "sin(x) + sin(y)");
        }
        """

        module = torch.utils.cpp_extension.load_inline(
            name="inline_jit_extension",
            cpp_sources=[cpp_source1, cpp_source2],
            verbose=True,
        )

        x = torch.randn(4, 4)
        y = torch.randn(4, 4)

        z = module.sin_add(x, y)
        self.assertEqual(z, x.sin() + y.sin())

    @unittest.skip("Temporarily disabled")
    @unittest.skipIf(not (TEST_CUDA or TEST_ROCM), "CUDA not found")
    def test_inline_jit_compile_extension_cuda(self):
        cuda_source = """
        __global__ void cos_add_kernel(
            const float* __restrict__ x,
            const float* __restrict__ y,
            float* __restrict__ output,
            const int size) {
          const auto index = blockIdx.x * blockDim.x + threadIdx.x;
          if (index < size) {
            output[index] = __cosf(x[index]) + __cosf(y[index]);
          }
        }

        torch::Tensor cos_add(torch::Tensor x, torch::Tensor y) {
          auto output = torch::zeros_like(x);
          const int threads = 1024;
          const int blocks = (output.numel() + threads - 1) / threads;
          cos_add_kernel<<<blocks, threads>>>(x.data<float>(), y.data<float>(), output.data<float>(), output.numel());
          return output;
        }
        """

        # Here, the C++ source need only declare the function signature.
        cpp_source = "torch::Tensor cos_add(torch::Tensor x, torch::Tensor y);"

        module = torch.utils.cpp_extension.load_inline(
            name="inline_jit_extension_cuda",
            cpp_sources=cpp_source,
            cuda_sources=cuda_source,
            functions=["cos_add"],
            verbose=True,
        )

        self.assertEqual(module.cos_add.__doc__.split("\n")[2], "cos_add")

        x = torch.randn(4, 4, device="cuda", dtype=torch.float32)
        y = torch.randn(4, 4, device="cuda", dtype=torch.float32)

        z = module.cos_add(x, y)
        self.assertEqual(z, x.cos() + y.cos())

    @unittest.skip("Temporarily disabled")
    @unittest.skipIf(not (TEST_CUDA or TEST_ROCM), "CUDA not found")
    def test_inline_jit_compile_custom_op_cuda(self):
        cuda_source = """
        __global__ void cos_add_kernel(
            const float* __restrict__ x,
            const float* __restrict__ y,
            float* __restrict__ output,
            const int size) {
          const auto index = blockIdx.x * blockDim.x + threadIdx.x;
          if (index < size) {
            output[index] = __cosf(x[index]) + __cosf(y[index]);
          }
        }

        torch::Tensor cos_add(torch::Tensor x, torch::Tensor y) {
          auto output = torch::zeros_like(x);
          const int threads = 1024;
          const int blocks = (output.numel() + threads - 1) / threads;
          cos_add_kernel<<<blocks, threads>>>(x.data_ptr<float>(), y.data_ptr<float>(), output.data_ptr<float>(), output.numel());
          return output;
        }
        """

        # Here, the C++ source need only declare the function signature.
        cpp_source = """
           #include <torch/library.h>
           torch::Tensor cos_add(torch::Tensor x, torch::Tensor y);

           TORCH_LIBRARY(inline_jit_extension_custom_op_cuda, m) {
             m.def("cos_add", cos_add);
           }
        """

        torch.utils.cpp_extension.load_inline(
            name="inline_jit_extension_custom_op_cuda",
            cpp_sources=cpp_source,
            cuda_sources=cuda_source,
            verbose=True,
            is_python_module=False,
        )

        x = torch.randn(4, 4, device="cuda", dtype=torch.float32)
        y = torch.randn(4, 4, device="cuda", dtype=torch.float32)

        z = torch.ops.inline_jit_extension_custom_op_cuda.cos_add(x, y)
        self.assertEqual(z, x.cos() + y.cos())

    def test_inline_jit_compile_extension_throws_when_functions_is_bad(self):
        with self.assertRaises(ValueError):
            torch.utils.cpp_extension.load_inline(
                name="invalid_jit_extension", cpp_sources="", functions=5
            )

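    # "Lenient" below means load_inline is expected to tolerate sloppy flag
    # formatting: stray whitespace and newlines in extra_cflags /
    # extra_include_paths should be cleaned up before they reach the compiler.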
    def test_lenient_flag_handling_in_jit_extensions(self):
        cpp_source = """
        torch::Tensor tanh_add(torch::Tensor x, torch::Tensor y) {
          return x.tanh() + y.tanh();
        }
        """

        module = torch.utils.cpp_extension.load_inline(
            name="lenient_flag_handling_extension",
            cpp_sources=cpp_source,
            functions="tanh_add",
            extra_cflags=["-g\n\n", "-O0 -Wall"],
            extra_include_paths=["       cpp_extensions\n"],
            verbose=True,
        )

        x = torch.zeros(100, dtype=torch.float32)
        y = torch.zeros(100, dtype=torch.float32)
        z = module.tanh_add(x, y).cpu()
        self.assertEqual(z, x.tanh() + y.tanh())

    @unittest.skip("Temporarily disabled")
    @unittest.skipIf(not (TEST_CUDA or TEST_ROCM), "CUDA not found")
    def test_half_support(self):
        """
        Checks for an issue with operator< ambiguity for half when certain
        THC headers are included.

        See https://github.com/pytorch/pytorch/pull/10301#issuecomment-416773333
        for the corresponding issue.
        """
        cuda_source = """
        template<typename T, typename U>
        __global__ void half_test_kernel(const T* input, U* output) {
            if (input[0] < input[1] || input[0] >= input[1]) {
                output[0] = 123;
            }
        }

        torch::Tensor half_test(torch::Tensor input) {
            auto output = torch::empty(1, input.options().dtype(torch::kFloat));
            AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "half_test", [&] {
                half_test_kernel<scalar_t><<<1, 1>>>(
                    input.data<scalar_t>(),
                    output.data<float>());
            });
            return output;
        }
        """

        module = torch.utils.cpp_extension.load_inline(
            name="half_test_extension",
            cpp_sources="torch::Tensor half_test(torch::Tensor input);",
            cuda_sources=cuda_source,
            functions=["half_test"],
            verbose=True,
        )

        x = torch.randn(3, device="cuda", dtype=torch.half)
        result = module.half_test(x)
        self.assertEqual(result[0], 123)

    def test_reload_jit_extension(self):
        def compile(code):
            return torch.utils.cpp_extension.load_inline(
                name="reloaded_jit_extension",
                cpp_sources=code,
                functions="f",
                verbose=True,
            )

        module = compile("int f() { return 123; }")
        self.assertEqual(module.f(), 123)

        module = compile("int f() { return 456; }")
        self.assertEqual(module.f(), 456)
        module = compile("int f() { return 456; }")
        self.assertEqual(module.f(), 456)

        module = compile("int f() { return 789; }")
        self.assertEqual(module.f(), 789)

    def test_cpp_frontend_module_has_same_output_as_python(self, dtype=torch.double):
        extension = torch.utils.cpp_extension.load(
            name="cpp_frontend_extension",
            sources="cpp_extensions/cpp_frontend_extension.cpp",
            verbose=True,
        )

        input = torch.randn(2, 5, dtype=dtype)
        cpp_linear = extension.Net(5, 2)
        cpp_linear.to(dtype)
        python_linear = torch.nn.Linear(5, 2).to(dtype)

        # First make sure they have the same parameters
        cpp_parameters = dict(cpp_linear.named_parameters())
        with torch.no_grad():
            python_linear.weight.copy_(cpp_parameters["fc.weight"])
            python_linear.bias.copy_(cpp_parameters["fc.bias"])

        cpp_output = cpp_linear.forward(input)
        python_output = python_linear(input)
        self.assertEqual(cpp_output, python_output)

        cpp_output.sum().backward()
        python_output.sum().backward()

        for p in cpp_linear.parameters():
            self.assertIsNotNone(p.grad)

        self.assertEqual(cpp_parameters["fc.weight"].grad, python_linear.weight.grad)
        self.assertEqual(cpp_parameters["fc.bias"].grad, python_linear.bias.grad)

    def test_cpp_frontend_module_python_inter_op(self):
        extension = torch.utils.cpp_extension.load(
            name="cpp_frontend_extension",
            sources="cpp_extensions/cpp_frontend_extension.cpp",
            verbose=True,
        )

        # Create a torch.nn.Module which uses the C++ module as a submodule.
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.nn.Parameter(torch.tensor(1.0))
                self.net = extension.Net(3, 5)

            def forward(self, input):
                return self.net.forward(input) + self.x

        net = extension.Net(5, 2)
        net.double()
        net.to(torch.get_default_dtype())
        self.assertEqual(str(net), "Net")

        # Further embed the torch.nn.Module into a Sequential, and also add the
        # C++ module as an element of the Sequential.
        sequential = torch.nn.Sequential(M(), torch.nn.Tanh(), net, torch.nn.Sigmoid())

        input = torch.randn(2, 3)
        # Try calling the module!
        output = sequential.forward(input)
        # The call operator is bound to forward too.
        self.assertEqual(output, sequential(input))
        self.assertEqual(list(output.shape), [2, 2])

        # Make changes to the module hierarchy.
        old_dtype = torch.get_default_dtype()
        sequential.to(torch.float64)
        sequential.to(torch.float32)
        sequential.to(old_dtype)
        self.assertEqual(sequential[2].parameters()[0].dtype, old_dtype)

        # Make sure we can access these methods recursively.
        self.assertEqual(
            len(list(sequential.parameters())), len(net.parameters()) * 2 + 1
        )
        self.assertEqual(
            len(list(sequential.named_parameters())),
            len(net.named_parameters()) * 2 + 1,
        )
        self.assertEqual(len(list(sequential.buffers())), len(net.buffers()) * 2)
        self.assertEqual(len(list(sequential.modules())), 8)

        # Test clone()
        net2 = net.clone()
        self.assertEqual(len(net.parameters()), len(net2.parameters()))
        self.assertEqual(len(net.buffers()), len(net2.buffers()))
        self.assertEqual(len(net.modules()), len(net2.modules()))

        # Try differentiating through the whole module.
        for parameter in net.parameters():
            self.assertIsNone(parameter.grad)
        output.sum().backward()
        for parameter in net.parameters():
            self.assertIsNotNone(parameter.grad)
            self.assertGreater(parameter.grad.sum(), 0)

        # Try calling zero_grad()
        net.zero_grad()
        for p in net.parameters():
            assert p.grad is None, "zero_grad defaults to setting grads to None"

        # Test train(), eval(), training (a property)
        self.assertTrue(net.training)
        net.eval()
        self.assertFalse(net.training)
        net.train()
        self.assertTrue(net.training)
        net.eval()

        # Try calling the additional methods we registered.
        biased_input = torch.randn(4, 5)
        output_before = net.forward(biased_input)
        bias = net.get_bias().clone()
        self.assertEqual(list(bias.shape), [2])
        net.set_bias(bias + 1)
        self.assertEqual(net.get_bias(), bias + 1)
        output_after = net.forward(biased_input)

        self.assertNotEqual(output_before, output_after)

        # Try accessing parameters
        self.assertEqual(len(net.parameters()), 2)
        np = net.named_parameters()
        self.assertEqual(len(np), 2)
        self.assertIn("fc.weight", np)
        self.assertIn("fc.bias", np)

        self.assertEqual(len(net.buffers()), 1)
        nb = net.named_buffers()
        self.assertEqual(len(nb), 1)
        self.assertIn("buf", nb)
        self.assertEqual(nb[0][1], torch.eye(5))

    def test_cpp_frontend_module_has_up_to_date_attributes(self):
        extension = torch.utils.cpp_extension.load(
            name="cpp_frontend_extension",
            sources="cpp_extensions/cpp_frontend_extension.cpp",
            verbose=True,
        )

        net = extension.Net(5, 2)

        self.assertEqual(len(net._parameters), 0)
        net.add_new_parameter("foo", torch.eye(5))
        self.assertEqual(len(net._parameters), 1)

        self.assertEqual(len(net._buffers), 1)
        net.add_new_buffer("bar", torch.eye(5))
        self.assertEqual(len(net._buffers), 2)

        self.assertEqual(len(net._modules), 1)
        net.add_new_submodule("fc2")
        self.assertEqual(len(net._modules), 2)

    @unittest.skipIf(not (TEST_CUDA or TEST_ROCM), "CUDA not found")
    def test_cpp_frontend_module_python_inter_op_with_cuda(self):
        extension = torch.utils.cpp_extension.load(
            name="cpp_frontend_extension",
            sources="cpp_extensions/cpp_frontend_extension.cpp",
            verbose=True,
        )

        net = extension.Net(5, 2)
        for p in net.parameters():
            self.assertTrue(p.device.type == "cpu")
        cpu_parameters = [p.clone() for p in net.parameters()]

        device = torch.device("cuda", 0)
        net.to(device)

        for i, p in enumerate(net.parameters()):
            self.assertTrue(p.device.type == "cuda")
            self.assertTrue(p.device.index == 0)
            self.assertEqual(cpu_parameters[i], p)

        net.cpu()
        net.add_new_parameter("a", torch.eye(5))
        net.add_new_parameter("b", torch.eye(5))
        net.add_new_buffer("c", torch.eye(5))
        net.add_new_buffer("d", torch.eye(5))
        net.add_new_submodule("fc2")
        net.add_new_submodule("fc3")

        for p in net.parameters():
            self.assertTrue(p.device.type == "cpu")

        net.cuda()

        for p in net.parameters():
            self.assertTrue(p.device.type == "cuda")

    def test_returns_shared_library_path_when_is_python_module_is_true(self):
        source = """
        #include <torch/script.h>
        torch::Tensor func(torch::Tensor x) { return x; }
        static torch::RegisterOperators r("test::func", &func);
        """
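        # With is_python_module=False the sources are built and loaded as a plain
        # shared library rather than imported as a Python extension module, so the
        # operator registered above is reached through torch.ops.test.func below.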
        torch.utils.cpp_extension.load_inline(
            name="is_python_module",
            cpp_sources=source,
            functions="func",
            verbose=True,
            is_python_module=False,
        )
        self.assertEqual(torch.ops.test.func(torch.eye(5)), torch.eye(5))

    def test_set_default_type_also_changes_aten_default_type(self):
        module = torch.utils.cpp_extension.load_inline(
            name="test_set_default_type",
            cpp_sources="torch::Tensor get() { return torch::empty({}); }",
            functions="get",
            verbose=True,
        )

        initial_default = torch.get_default_dtype()
        try:
            self.assertEqual(module.get().dtype, initial_default)
            torch.set_default_dtype(torch.float64)
            self.assertEqual(module.get().dtype, torch.float64)
            torch.set_default_dtype(torch.float32)
            self.assertEqual(module.get().dtype, torch.float32)
            torch.set_default_dtype(torch.float16)
            self.assertEqual(module.get().dtype, torch.float16)
        finally:
            torch.set_default_dtype(initial_default)

    def test_compilation_error_formatting(self):
        # Test that the missing-semicolon error message has linebreaks in it.
        # This'll fail if the message has been munged into a single line.
        # It's hard to write anything more specific as every compiler has its own
        # error formatting.
        with self.assertRaises(RuntimeError) as e:
            torch.utils.cpp_extension.load_inline(
                name="test_compilation_error_formatting",
                cpp_sources="int main() { return 0 }",
            )
        pattern = r".*(\\n|\\r).*"
        self.assertNotRegex(str(e), pattern)

    def test_warning(self):
        # Note: the module created from this source will include the py::key_error
        # symbol. But because of visibility and the fact that it lives in a
        # different compilation unit than pybind, this trips up ubsan even though
        # it is fine. "ubsan.supp" thus needs to contain "vptr:warn_mod.so".
        source = """
        // error_type:
        // 0: no error
        // 1: torch::TypeError
        // 2: python_error()
        // 3: py::error_already_set
        at::Tensor foo(at::Tensor x, int error_type) {
            std::ostringstream err_stream;
            err_stream << "Error with "  << x.type();

            TORCH_WARN(err_stream.str());
            if(error_type == 1) {
                throw torch::TypeError(err_stream.str().c_str());
            }
            if(error_type == 2) {
                PyObject* obj = PyTuple_New(-1);
                TORCH_CHECK(!obj);
                // Pretend it was caught in a different thread and restored here
                auto e = python_error();
                e.persist();
                e.restore();
                throw e;
            }
            if(error_type == 3) {
                throw py::key_error(err_stream.str());
            }
            return x.cos();
        }
        """

        # Ensure double type for the hard-coded C++ type name below
        t = torch.rand(2).double()
        cpp_tensor_name = r"CPUDoubleType"

        # Without error handling, the warnings cannot be caught
        warn_mod = torch.utils.cpp_extension.load_inline(
            name="warn_mod",
            cpp_sources=[source],
            functions=["foo"],
            with_pytorch_error_handling=False,
        )

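        # with_pytorch_error_handling wraps each bound function so that C++
        # warnings and errors are translated into Python warnings/exceptions.
        # With it disabled (above), TORCH_WARN output never reaches the Python
        # warnings machinery, which is what the first block below checks.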
        with warnings.catch_warnings(record=True) as w:
            warn_mod.foo(t, 0)
            self.assertEqual(len(w), 0)

            with self.assertRaisesRegex(TypeError, t.type()):
                warn_mod.foo(t, 1)
            self.assertEqual(len(w), 0)

            with self.assertRaisesRegex(
                SystemError, "bad argument to internal function"
            ):
                warn_mod.foo(t, 2)
            self.assertEqual(len(w), 0)

            with self.assertRaisesRegex(KeyError, cpp_tensor_name):
                warn_mod.foo(t, 3)
            self.assertEqual(len(w), 0)

        warn_mod = torch.utils.cpp_extension.load_inline(
            name="warn_mod",
            cpp_sources=[source],
            functions=["foo"],
            with_pytorch_error_handling=True,
        )

        with warnings.catch_warnings(record=True) as w:
            # A warning raised with no error should be caught
            warn_mod.foo(t, 0)
            self.assertEqual(len(w), 1)

            # A warning raised alongside a C++ error should also be caught
            with self.assertRaisesRegex(TypeError, t.type()):
                warn_mod.foo(t, 1)
            self.assertEqual(len(w), 2)

            # A warning raised alongside a Python error should also be caught
            with self.assertRaisesRegex(
                SystemError, "bad argument to internal function"
            ):
                warn_mod.foo(t, 2)
            self.assertEqual(len(w), 3)

            # A warning raised alongside a pybind error should also be caught
            # Note that there is no type name translation for pybind errors
            with self.assertRaisesRegex(KeyError, cpp_tensor_name):
                warn_mod.foo(t, 3)
            self.assertEqual(len(w), 4)

        # Make sure warnings that are turned into errors are handled properly
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("error")

            # No error from foo itself, so the warning is raised as an error
            with self.assertRaisesRegex(UserWarning, t.type()):
                warn_mod.foo(t, 0)
            self.assertEqual(len(w), 0)

            # Another error happened, so the warning is ignored
            with self.assertRaisesRegex(TypeError, t.type()):
                warn_mod.foo(t, 1)
            self.assertEqual(len(w), 0)

    def test_autograd_from_cpp(self):
        source = """
        void run_back(at::Tensor x) {
            x.backward({});
        }

        void run_back_no_gil(at::Tensor x) {
            pybind11::gil_scoped_release no_gil;
            x.backward({});
        }
        """

        class MyFn(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x.clone()

            @staticmethod
            def backward(ctx, gx):
                return gx

        test_backward_deadlock = torch.utils.cpp_extension.load_inline(
            name="test_backward_deadlock",
            cpp_sources=[source],
            functions=["run_back", "run_back_no_gil"],
        )

        # This used to deadlock
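        # run_back() drives backward() while the calling thread still holds the
        # GIL; the autograd engine detects this and raises instead of deadlocking,
        # which is what the assertion below checks. run_back_no_gil() releases the
        # GIL first, so it is expected to succeed.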
        inp = torch.rand(20, requires_grad=True)
        loss = MyFn.apply(inp).sum()
        with self.assertRaisesRegex(
            RuntimeError, "The autograd engine was called while holding the GIL."
        ):
            test_backward_deadlock.run_back(loss)

        inp = torch.rand(20, requires_grad=True)
        loss = MyFn.apply(inp).sum()
        test_backward_deadlock.run_back_no_gil(loss)

    def test_custom_compound_op_autograd(self):
        # Test that a custom compound op (i.e. a custom op that just calls other aten ops)
        # correctly returns gradients of those other ops

        source = """
        #include <torch/library.h>
        torch::Tensor my_add(torch::Tensor x, torch::Tensor y) {
          return x + y;
        }
        TORCH_LIBRARY(my, m) {
            m.def("add", &my_add);
        }
        """

        torch.utils.cpp_extension.load_inline(
            name="is_python_module",
            cpp_sources=source,
            verbose=True,
            is_python_module=False,
        )

        a = torch.randn(5, 5, requires_grad=True)
        b = torch.randn(5, 5, requires_grad=True)

        for fast_mode in (True, False):
            gradcheck(torch.ops.my.add, [a, b], eps=1e-2, fast_mode=fast_mode)

    def test_custom_functorch_error(self):
        # Test that a custom C++ Function raises an error under functorch transforms
        identity_m = torch.utils.cpp_extension.load(
            name="identity",
            sources=["cpp_extensions/identity.cpp"],
        )

        t = torch.randn(3, requires_grad=True)

        msg = r"cannot use C\+\+ torch::autograd::Function with functorch"
        with self.assertRaisesRegex(RuntimeError, msg):
            torch.func.vmap(identity_m.identity)(t)

        with self.assertRaisesRegex(RuntimeError, msg):
            torch.func.grad(identity_m.identity)(t)

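    # use_pch=True asks cpp_extension to precompile torch/extension.h (producing
    # a .gch file plus a .sign signature file next to the header). Precompiled
    # headers are only verified here for GCC on Linux, hence the guards below.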
    def test_gen_extension_h_pch(self):
        if not IS_LINUX:
            return

        source = """
        at::Tensor sin_add(at::Tensor x, at::Tensor y) {
            return x.sin() + y.sin();
        }
        """

        head_file_pch = os.path.join(_TORCH_PATH, "include", "torch", "extension.h.gch")
        head_file_signature = os.path.join(
            _TORCH_PATH, "include", "torch", "extension.h.sign"
        )

        remove_extension_h_precompiler_headers()
        pch_exist = os.path.exists(head_file_pch)
        signature_exist = os.path.exists(head_file_signature)
        self.assertEqual(pch_exist, False)
        self.assertEqual(signature_exist, False)

        torch.utils.cpp_extension.load_inline(
            name="inline_extension_with_pch",
            cpp_sources=[source],
            functions=["sin_add"],
            verbose=True,
            use_pch=True,
        )
        pch_exist = os.path.exists(head_file_pch)
        signature_exist = os.path.exists(head_file_signature)

        compiler = get_cxx_compiler()
        if check_compiler_is_gcc(compiler):
            self.assertEqual(pch_exist, True)
            self.assertEqual(signature_exist, True)


if __name__ == "__main__":
    common.run_tests()