from __future__ import annotations

import math
import os
import subprocess
from pathlib import Path
from typing import Callable, Sequence

from tools.stats.import_test_stats import get_disabled_tests
from tools.testing.test_run import ShardedTest, TestRun


REPO_ROOT = Path(__file__).resolve().parent.parent.parent

IS_MEM_LEAK_CHECK = os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1"
BUILD_ENVIRONMENT = os.getenv("BUILD_ENVIRONMENT", "")
USE_3_PROCS = "sm86" in BUILD_ENVIRONMENT or "cuda" not in BUILD_ENVIRONMENT

# NUM_PROCS_FOR_SHARDING_CALC must remain consistent across all shards of a job
# so that sharding is deterministic.  NUM_PROCS is the actual number of procs
# used to run tests.  If the two are not equal, the only consequence should be
# unequal shards.
IS_ROCM = os.path.exists("/opt/rocm")
NUM_PROCS = 1 if IS_MEM_LEAK_CHECK else 3 if USE_3_PROCS else 2
NUM_PROCS_FOR_SHARDING_CALC = NUM_PROCS if not IS_ROCM or IS_MEM_LEAK_CHECK else 2
THRESHOLD = 60 * 10  # 10 minutes

# See Note [ROCm parallel CI testing]
# Special logic for ROCm GHA runners to query the number of GPUs available.
# torch.version.hip is not available here to check whether this is a ROCm
# self-hosted runner, so we detect one by looking for the /opt/rocm directory.
if IS_ROCM and not IS_MEM_LEAK_CHECK:
    try:
        # This is the same logic used in the GHA health check, see .github/templates/common.yml.j2
        lines = (
            subprocess.check_output(["rocminfo"], encoding="ascii").strip().split("\n")
        )
        count = 0
        for line in lines:
            if " gfx" in line:
                count += 1
        assert count > 0  # there must be at least 1 GPU
        # Limit to 8 GPUs (procs)
        NUM_PROCS = min(count, 8)
    except subprocess.CalledProcessError:
        # The safe default for ROCm GHA runners is to run tests serially.
        NUM_PROCS = 1


class ShardJob:
    def __init__(self) -> None:
        self.serial: list[ShardedTest] = []
        self.parallel: list[ShardedTest] = []

    def get_total_time(self) -> float:
        """Estimate this job's total runtime: parallel tests are greedily packed
        onto NUM_PROCS_FOR_SHARDING_CALC procs and serial tests are added on top."""
        procs = [0.0 for _ in range(NUM_PROCS_FOR_SHARDING_CALC)]
        for test in self.parallel:
            min_index = procs.index(min(procs))
            procs[min_index] += test.get_time()
        time = max(procs) + sum(test.get_time() for test in self.serial)
        return time

    def convert_to_tuple(self) -> tuple[float, list[ShardedTest]]:
        return (self.get_total_time(), self.serial + self.parallel)

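# Illustrative sketch (not part of the original module): how ShardJob's greedy
# packing estimates a shard's runtime. Parallel tests land on whichever of the
# NUM_PROCS_FOR_SHARDING_CALC procs is least loaded; serial tests are added on
# top. ShardedTest is constructed positionally as elsewhere in this file, and
# TestRun is assumed to accept a bare file name as in tools/testing/test_run.py.
def _example_shard_job_time() -> float:
    job = ShardJob()
    # Two parallel tests of 60s and 30s go to different procs, so the parallel
    # portion contributes max(60, 30) = 60s when NUM_PROCS_FOR_SHARDING_CALC >= 2.
    job.parallel.append(ShardedTest(TestRun("test_a"), 1, 1, 60.0))
    job.parallel.append(ShardedTest(TestRun("test_b"), 1, 1, 30.0))
    # A serial test always adds its full time.
    job.serial.append(ShardedTest(TestRun("test_c"), 1, 1, 45.0))
    return job.get_total_time()  # 105.0 with 2+ procs
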

def get_with_pytest_shard(
    tests: Sequence[TestRun],
    test_file_times: dict[str, float],
    test_class_times: dict[str, dict[str, float]] | None,
) -> list[ShardedTest]:
    sharded_tests: list[ShardedTest] = []

    for test in tests:
        duration = get_duration(test, test_file_times, test_class_times or {})

        if duration and duration > THRESHOLD:
            num_shards = math.ceil(duration / THRESHOLD)
            for i in range(num_shards):
                sharded_tests.append(
                    ShardedTest(test, i + 1, num_shards, duration / num_shards)
                )
        else:
            sharded_tests.append(ShardedTest(test, 1, 1, duration))
    return sharded_tests

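# Illustrative sketch (not part of the original module): a test whose expected
# duration exceeds THRESHOLD (10 minutes) is split into ceil(duration / THRESHOLD)
# pytest shards of equal estimated time. The test name and timing below are made
# up; TestRun is assumed to accept a bare file name as in tools/testing/test_run.py.
def _example_pytest_sharding() -> list[ShardedTest]:
    # A 25-minute test file is split into ceil(1500 / 600) = 3 pytest shards,
    # each estimated at 500 seconds.
    return get_with_pytest_shard([TestRun("test_slow")], {"test_slow": 1500.0}, None)
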

def get_duration(
    test: TestRun,
    test_file_times: dict[str, float],
    test_class_times: dict[str, dict[str, float]],
) -> float | None:
    """Calculate the time for a TestRun based on the given test_file_times and
    test_class_times.  Returns None if the time is unknown."""
    file_duration = test_file_times.get(test.test_file, None)
    if test.is_full_file():
        return file_duration

    def get_duration_for_classes(
        test_file: str, test_classes: frozenset[str]
    ) -> float | None:
        duration: float = 0

        for test_class in test_classes:
            class_duration = test_class_times.get(test_file, {}).get(test_class, None)
            if class_duration is None:
                return None
            duration += class_duration
        return duration

    included = test.included()
    excluded = test.excluded()
    included_classes_duration = get_duration_for_classes(test.test_file, included)
    excluded_classes_duration = get_duration_for_classes(test.test_file, excluded)

    if included_classes_duration is None or excluded_classes_duration is None:
        # Didn't get the time for all classes, so time is unknown
        return None

    if included:
        return included_classes_duration
    assert (
        excluded
    ), f"TestRun {test} is not full file but doesn't have included or excluded classes"
    if file_duration is None:
        return None
    return file_duration - excluded_classes_duration

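# Illustrative sketch (not part of the original module): get_duration resolves a
# whole-file TestRun from test_file_times and a class-level TestRun from
# test_class_times, returning None when any required time is missing. The
# "file::Class" spelling assumes TestRun parses class names that way, as in
# tools/testing/test_run.py; the names and times below are made up.
def _example_get_duration() -> None:
    file_times = {"test_ops": 100.0}
    class_times = {"test_ops": {"TestCommon": 30.0}}
    # Full file: falls back to test_file_times -> 100.0
    assert get_duration(TestRun("test_ops"), file_times, class_times) == 100.0
    # Single class: summed from test_class_times -> 30.0
    assert get_duration(TestRun("test_ops::TestCommon"), file_times, class_times) == 30.0
    # Unknown file -> None
    assert get_duration(TestRun("test_unknown"), file_times, class_times) is None
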

def shard(
    sharded_jobs: list[ShardJob],
    pytest_sharded_tests: Sequence[ShardedTest],
    estimated_time_limit: float | None = None,
    serial: bool = False,
) -> None:
    # Modifies sharded_jobs in place
    if len(sharded_jobs) == 0:
        assert (
            len(pytest_sharded_tests) == 0
        ), "No shards provided but there are tests to shard"
        return

    round_robin_index = 0

    def _get_min_sharded_job(
        sharded_jobs: list[ShardJob], test: ShardedTest
    ) -> ShardJob:
        if test.time is None:
            nonlocal round_robin_index
            job = sharded_jobs[round_robin_index % len(sharded_jobs)]
            round_robin_index += 1
            return job
        return min(sharded_jobs, key=lambda j: j.get_total_time())

    def _shard_serial(
        tests: Sequence[ShardedTest], sharded_jobs: list[ShardJob]
    ) -> None:
        assert estimated_time_limit is not None, "Estimated time limit must be provided"
        new_sharded_jobs = sharded_jobs
        for test in tests:
            if (
                len(sharded_jobs) > 1
                and sharded_jobs[-1].get_total_time() > estimated_time_limit
            ):
                new_sharded_jobs = sharded_jobs[:-1]
            min_sharded_job = _get_min_sharded_job(new_sharded_jobs, test)
            min_sharded_job.serial.append(test)

    def _shard_parallel(
        tests: Sequence[ShardedTest], sharded_jobs: list[ShardJob]
    ) -> None:
        for test in tests:
            min_sharded_job = _get_min_sharded_job(sharded_jobs, test)
            min_sharded_job.parallel.append(test)

    if serial:
        _shard_serial(pytest_sharded_tests, sharded_jobs)
    else:
        _shard_parallel(pytest_sharded_tests, sharded_jobs)

    return

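# Illustrative sketch (not part of the original module): shard() appends each
# test to the ShardJob with the smallest estimated total time, falling back to
# round-robin for tests with unknown (None) durations. The names and times
# below are made up.
def _example_shard() -> list[tuple[float, list[ShardedTest]]]:
    jobs = [ShardJob(), ShardJob()]
    tests = [
        ShardedTest(TestRun("test_a"), 1, 1, 120.0),
        ShardedTest(TestRun("test_b"), 1, 1, 60.0),
        ShardedTest(TestRun("test_c"), 1, 1, 60.0),
    ]
    # test_a goes to job 0, test_b to job 1, test_c back to job 1 (60 < 120).
    shard(sharded_jobs=jobs, pytest_sharded_tests=tests, serial=False)
    return [job.convert_to_tuple() for job in jobs]
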

def calculate_shards(
    num_shards: int,
    tests: Sequence[TestRun],
    test_file_times: dict[str, float],
    test_class_times: dict[str, dict[str, float]] | None,
    must_serial: Callable[[str], bool] | None = None,
    sort_by_time: bool = True,
) -> list[tuple[float, list[ShardedTest]]]:
    must_serial = must_serial or (lambda x: True)
    test_class_times = test_class_times or {}

    # Divide tests into pytest shards
    if sort_by_time:
        known_tests = [
            x
            for x in tests
            if get_duration(x, test_file_times, test_class_times) is not None
        ]
        unknown_tests = [x for x in tests if x not in known_tests]

        pytest_sharded_tests = sorted(
            get_with_pytest_shard(known_tests, test_file_times, test_class_times),
            key=lambda j: j.get_time(),
            reverse=True,
        ) + get_with_pytest_shard(unknown_tests, test_file_times, test_class_times)
    else:
        pytest_sharded_tests = get_with_pytest_shard(
            tests, test_file_times, test_class_times
        )
    del tests

    serial_tests = [test for test in pytest_sharded_tests if must_serial(test.name)]
    parallel_tests = [test for test in pytest_sharded_tests if test not in serial_tests]

    serial_time = sum(test.get_time() for test in serial_tests)
    parallel_time = sum(test.get_time() for test in parallel_tests)
    total_time = serial_time + parallel_time / NUM_PROCS_FOR_SHARDING_CALC
    estimated_time_per_shard = total_time / num_shards
    # Separate serial tests from parallel tests as much as possible to maximize
    # parallelism by putting all the serial tests on the first num_serial_shards
    # shards. The estimated_time_limit is the estimated time it should take for
    # the least filled serial shard. For example, with 8 min of serial tests,
    # 20 min of parallel tests, 6 shards, and 2 procs per machine, we would
    # expect each machine to take 3 min and should aim for 3 serial shards, with
    # shards 1 and 2 taking 3 min and shard 3 taking 2 min. The estimated time
    # limit would be 2 min. This ensures that the first few shards contain as
    # many serial tests as possible and as few parallel tests as possible. The
    # least filled/last shard (the 3rd in the example) may contain a lot of both
    # serial and parallel tests.
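    # Working that example through the formulas below: total_time = 8 + 20 / 2
    # = 18 min, estimated_time_per_shard = 18 / 6 = 3 min, estimated_time_limit
    # = 8 % 3 = 2 min, and num_serial_shards = ceil(8 / 18 * 6) = 3.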
    estimated_time_limit = 0.0
    if estimated_time_per_shard != 0:
        estimated_time_limit = serial_time % estimated_time_per_shard
    if estimated_time_limit <= 0.01:
        estimated_time_limit = estimated_time_per_shard
    if total_time == 0:
        num_serial_shards = num_shards
    else:
        num_serial_shards = max(math.ceil(serial_time / total_time * num_shards), 1)

    sharded_jobs = [ShardJob() for _ in range(num_shards)]
    shard(
        sharded_jobs=sharded_jobs[:num_serial_shards],
        pytest_sharded_tests=serial_tests,
        estimated_time_limit=estimated_time_limit,
        serial=True,
    )
    shard(
        sharded_jobs=sharded_jobs,
        pytest_sharded_tests=parallel_tests,
        serial=False,
    )

    return [job.convert_to_tuple() for job in sharded_jobs]

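# Illustrative sketch (not part of the original module): end-to-end use of
# calculate_shards. With must_serial marking everything as parallel, the three
# files below are distributed across two shards, each new test going to the
# shard with the smallest estimated time. The file names and timings are made
# up, and TestRun is assumed to accept a bare file name.
def _example_calculate_shards() -> list[tuple[float, list[ShardedTest]]]:
    tests = [TestRun("test_a"), TestRun("test_b"), TestRun("test_c")]
    file_times = {"test_a": 600.0, "test_b": 300.0, "test_c": 300.0}
    # Returns one (estimated_time, tests) tuple per shard.
    return calculate_shards(
        num_shards=2,
        tests=tests,
        test_file_times=file_times,
        test_class_times=None,
        must_serial=lambda name: False,
        sort_by_time=True,
    )
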

def get_test_case_configs(dirpath: str) -> None:
    get_disabled_tests(dirpath=dirpath)