# xref: /aosp_15_r20/external/pytorch/test/test_cuda_expandable_segments.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["module: cuda"]
# Run the CUDA tests from test_cuda, but with the allocator using expandable segments.

import pathlib
import sys

# The CUDA suites are imported only so they are bound in this module's
# namespace (presumably so run_tests() picks them up from here); the F401
# suppression marks the otherwise-unused names, e.g. TestCudaMallocAsync,
# as intentional.
from test_cuda import (  # noqa: F401
    TestBlockStateAbsorption,
    TestCuda,
    TestCudaMallocAsync,
)

import torch
from torch.testing._internal.common_cuda import IS_JETSON, IS_WINDOWS
from torch.testing._internal.common_utils import run_tests, TEST_WITH_ROCM


# Two directories above this file (test/ -> repository root); the `tools`
# package imported below is expected to live there.
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent

# Temporarily put the repository root first on the module search path so
# that `tools` is importable.
sys.path.insert(0, str(REPO_ROOT))
try:
    from tools.stats.import_test_stats import get_disabled_tests
finally:
    # Make sure to remove REPO_ROOT after the import is done -- even if the
    # import raises -- so the temporary entry never leaks into sys.path.
    sys.path.remove(str(REPO_ROOT))

if __name__ == "__main__":
    # Jetson, Windows and ROCm builds are excluded from this variant.
    _unsupported_platform = IS_JETSON or IS_WINDOWS or TEST_WITH_ROCM

    # Without a CUDA device, or on an excluded platform, the script exits
    # without running any tests.
    if torch.cuda.is_available() and not _unsupported_platform:
        # NOTE(review): "." is presumably the directory the disabled-test
        # list is written to -- confirm in tools.stats.import_test_stats.
        get_disabled_tests(".")

        # Switch the CUDA allocator to expandable segments before any test
        # runs.
        torch.cuda.memory._set_allocator_settings("expandable_segments:True")

        # Make the expandable_segments() hook report True for these two
        # suites; TestCudaMallocAsync is left unpatched.
        TestBlockStateAbsorption.expandable_segments = lambda _: True
        TestCuda.expandable_segments = lambda _: True

        run_tests()