# Owner(s): ["module: inductor"]

import sys
import unittest

from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS, skipIfRocm
from torch.testing._internal.inductor_utils import HAS_CUDA


if IS_WINDOWS and IS_CI:
    sys.stderr.write(
        "Windows CI does not have necessary dependencies for test_memory_planning yet\n"
    )
    if __name__ == "__main__":
        sys.exit(0)
    raise unittest.SkipTest("requires sympy/functorch/filelock")  # noqa: F821

import torch
from torch._C import FileCheck
from torch._dynamo.utils import same
from torch._inductor import config
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_cpp_code
from torch.export import Dim
from torch.utils._triton import has_triton


@unittest.skipIf(not has_triton(), "Inductor+GPU needs Triton and a recent GPU arch")
@config.patch(memory_planning=True)
class TestMemoryPlanning(TestCase):
    def _generate(self, *, device):
        """
        Generate a simple test case that has multiple simultaneously-live
        intermediate tensors, so memory planning has something to pool.
        """

        class Foo(torch.nn.Module):
            def forward(self, x, y, z):
                t0 = x.matmul(y)
                t1 = x.matmul(z)
                t0 = x.transpose(0, 1).matmul(t1)
                t1 = x.matmul(t0)
                return t0.sum() + t1.sum()

        x = torch.randn((3, 2), device=device)
        y = torch.randn((2, 4), device=device)
        z = torch.randn((2, 3), device=device)
        return (Foo(), (x, y, z))

    def test_python_wrapper(self):
        f, args = self._generate(device="cuda")
        compiled = torch.compile(f, dynamic=True)
        result, code = run_and_get_cpp_code(compiled, *args)

        # The Python wrapper should allocate a single pool and carve the
        # intermediate buffers out of it with alloc_from_pool.
        FileCheck().check(
            "pool1 = empty_strided_cuda(((4*s0*s1) + (align(4*(s0*s0))), ), (1, )"
        ).check_next(
            "buf0 = alloc_from_pool(pool1, 0, torch.float32, (s0, s0), (s0, 1))"
        ).check(
            "buf1 = alloc_from_pool(pool1, align(4*(s0*s0)),"
        ).run(code)
        self.assertTrue(same(f(*args), result))

    def test_cpp_wrapper(self):
        f, args = self._generate(device="cuda")
        compiled = torch.compile(f, dynamic=True)
        with config.patch({"cpp_wrapper": True, "abi_compatible": False}):
            result, code = run_and_get_cpp_code(compiled, *args)

        # Same pooling pattern, but in the generated C++ wrapper code.
        FileCheck().check(
            "pool1 = at::detail::empty_strided_cuda({(4L*s0*s1) + (align(4L*(static_cast<int64_t>(s0*s0)))), }, {1L, }"
        ).check_next(
            "auto buf0 = alloc_from_pool(pool1, 0, at::kFloat, {s0, s0}, {s0, 1L});"
        ).check(
            "auto buf1 = alloc_from_pool(pool1, align(4L*(static_cast<int64_t>(s0*s0))),"
        ).run(code)
        self.assertTrue(same(f(*args), result))

    @skipIfRocm(msg="test_aot_inductor doesn't work on ROCm")
    def test_abi_compatible(self):
        try:
            from .test_aot_inductor import AOTIRunnerUtil
        except ImportError:
            from test_aot_inductor import AOTIRunnerUtil

        f, args = self._generate(device="cuda")
        dim0_x = Dim("dim0_x", min=1, max=2048)
        dynamic_shapes = ({0: dim0_x}, None, None)
        with config.patch("abi_compatible", True):
            result, code = run_and_get_cpp_code(
                lambda: AOTIRunnerUtil.run(
                    "cuda", f, args, dynamic_shapes=dynamic_shapes
                )
            )

        # In the ABI-compatible (AOTInductor) path, the pool is created via
        # aoti_torch_empty_strided and buffers come from aoti_torch__alloc_from_pool.
        FileCheck().check(
            "int64_t int_array_2[] = {24L + (align(12L*s0)), };"
        ).check_next(
            "int64_t int_array_3[] = {1L, };"
        ).check_next(
            "AtenTensorHandle pool1_handle;"
        ).check_next(
            "aoti_torch_empty_strided(1, int_array_2, int_array_3,"
        ).check_next(
            "RAIIAtenTensorHandle pool1(pool1_handle);"
        ).check_next(
            "int64_t int_array_4[] = {s0, 3L};"
        ).check_next(
            "int64_t int_array_5[] = {3L, 1L};"
        ).check_next(
            "AtenTensorHandle tmp_tensor_handle_1;"
        ).check_next(
            "aoti_torch__alloc_from_pool(pool1, 0"
        ).run(code)
        self.assertTrue(same(f(*args), result))


if __name__ == "__main__":
    if HAS_CUDA:
        run_tests()