# Owner(s): ["module: inductor"]

import sys
import unittest

from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS, skipIfRocm
from torch.testing._internal.inductor_utils import HAS_CUDA


if IS_WINDOWS and IS_CI:
    sys.stderr.write(
        "Windows CI does not have necessary dependencies for test_memory_planning yet\n"
    )
    if __name__ == "__main__":
        sys.exit(0)
    raise unittest.SkipTest("requires sympy/functorch/filelock")  # noqa: F821

import torch
from torch._C import FileCheck
from torch._dynamo.utils import same
from torch._inductor import config
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_cpp_code
from torch.export import Dim
from torch.utils._triton import has_triton


@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@config.patch(memory_planning=True)
class TestMemoryPlanning(TestCase):
    def _generate(self, *, device):
        """
        Generate a simple test case that has multiple simultaneously-live intermediate tensors.
        """

        class Foo(torch.nn.Module):
            def forward(self, x, y, z):
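                # t0 and t1 are both live across the calls below, so the memory
                # planner has to give them non-overlapping regions of the pool.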
                t0 = x.matmul(y)
                t1 = x.matmul(z)
                t0 = x.transpose(0, 1).matmul(t1)
                t1 = x.matmul(t0)
                return t0.sum() + t1.sum()

        x = torch.randn((3, 2), device=device)
        y = torch.randn((2, 4), device=device)
        z = torch.randn((2, 3), device=device)
        return (Foo(), (x, y, z))

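    # Checks that the Python wrapper allocates one backing pool via
    # empty_strided_cuda and carves the intermediate buffers out of it with
    # alloc_from_pool at aligned offsets.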
    def test_python_wrapper(self):
        f, args = self._generate(device="cuda")
        compiled = torch.compile(f, dynamic=True)
        result, code = run_and_get_cpp_code(compiled, *args)

        FileCheck().check(
            "pool1 = empty_strided_cuda(((4*s0*s1) + (align(4*(s0*s0))), ), (1, )"
        ).check_next(
            "buf0 = alloc_from_pool(pool1, 0, torch.float32, (s0, s0), (s0, 1))"
        ).check(
            "buf1 = alloc_from_pool(pool1, align(4*(s0*s0)),"
        ).run(
            code
        )
        self.assertTrue(same(f(*args), result))

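    # Same check against the C++ wrapper: the pool comes from
    # at::detail::empty_strided_cuda and the buffers are sliced out of it with
    # alloc_from_pool at aligned byte offsets.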
    def test_cpp_wrapper(self):
        f, args = self._generate(device="cuda")
        compiled = torch.compile(f, dynamic=True)
        with config.patch({"cpp_wrapper": True, "abi_compatible": False}):
            result, code = run_and_get_cpp_code(compiled, *args)

        FileCheck().check(
            "pool1 = at::detail::empty_strided_cuda({(4L*s0*s1) + (align(4L*(static_cast<int64_t>(s0*s0)))), }, {1L, }"
        ).check_next(
            "auto buf0 = alloc_from_pool(pool1, 0, at::kFloat, {s0, s0}, {s0, 1L});"
        ).check(
            "auto buf1 = alloc_from_pool(pool1, align(4L*(static_cast<int64_t>(s0*s0))),"
        ).run(
            code
        )
        self.assertTrue(same(f(*args), result))

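    # In ABI-compatible AOTInductor codegen, the pool is allocated through the
    # C shim (aoti_torch_empty_strided), wrapped in an RAIIAtenTensorHandle, and
    # the intermediate buffers are taken from it via aoti_torch__alloc_from_pool.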
    @skipIfRocm(msg="test_aot_inductor doesn't work on ROCm")
    def test_abi_compatible(self):
        try:
            from .test_aot_inductor import AOTIRunnerUtil
        except ImportError:
            from test_aot_inductor import AOTIRunnerUtil

        f, args = self._generate(device="cuda")
        dim0_x = Dim("dim0_x", min=1, max=2048)
        dynamic_shapes = ({0: dim0_x}, None, None)
        with config.patch("abi_compatible", True):
            result, code = run_and_get_cpp_code(
                lambda: AOTIRunnerUtil.run(
                    "cuda", f, args, dynamic_shapes=dynamic_shapes
                )
            )

        FileCheck().check(
            "int64_t int_array_2[] = {24L + (align(12L*s0)), };"
        ).check_next("int64_t int_array_3[] = {1L, };").check_next(
            "AtenTensorHandle pool1_handle;"
        ).check_next(
            "aoti_torch_empty_strided(1, int_array_2, int_array_3,"
        ).check_next(
            "RAIIAtenTensorHandle pool1(pool1_handle);"
        ).check_next(
            "int64_t int_array_4[] = {s0, 3L};"
        ).check_next(
            "int64_t int_array_5[] = {3L, 1L};"
        ).check_next(
            "AtenTensorHandle tmp_tensor_handle_1;"
        ).check_next(
            "aoti_torch__alloc_from_pool(pool1, 0"
        ).run(
            code
        )
        self.assertTrue(same(f(*args), result))


if __name__ == "__main__":
    if HAS_CUDA:
        run_tests()