from __future__ import annotations

import math
import os
import subprocess
from pathlib import Path
from typing import Callable, Sequence

from tools.stats.import_test_stats import get_disabled_tests
from tools.testing.test_run import ShardedTest, TestRun


REPO_ROOT = Path(__file__).resolve().parent.parent.parent

IS_MEM_LEAK_CHECK = os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1"
BUILD_ENVIRONMENT = os.getenv("BUILD_ENVIRONMENT", "")
USE_3_PROCS = "sm86" in BUILD_ENVIRONMENT or "cuda" not in BUILD_ENVIRONMENT

# NUM_PROCS_FOR_SHARDING_CALC must remain consistent across all shards of a job
# so that sharding is consistent; NUM_PROCS is the actual number of procs used
# to run tests. If they are not equal, the only consequence should be unequal
# shards.
IS_ROCM = os.path.exists("/opt/rocm")
NUM_PROCS = 1 if IS_MEM_LEAK_CHECK else 3 if USE_3_PROCS else 2
NUM_PROCS_FOR_SHARDING_CALC = NUM_PROCS if not IS_ROCM or IS_MEM_LEAK_CHECK else 2
THRESHOLD = 60 * 10  # 10 minutes

# See Note [ROCm parallel CI testing]
# Special logic for ROCm GHA runners to query the number of GPUs available.
# torch.version.hip is not available here to check whether this is a ROCm
# self-hosted runner, so we check for the /opt/rocm directory instead.
if IS_ROCM and not IS_MEM_LEAK_CHECK:
    try:
        # This is the same logic used in the GHA health check, see .github/templates/common.yml.j2
        lines = (
            subprocess.check_output(["rocminfo"], encoding="ascii").strip().split("\n")
        )
        count = 0
        for line in lines:
            if " gfx" in line:
                count += 1
        assert count > 0  # there must be at least 1 GPU
        # Limit to 8 GPUs (procs)
        NUM_PROCS = min(count, 8)
    except subprocess.CalledProcessError:
        # The safe default for ROCm GHA runners is to run tests serially.
        NUM_PROCS = 1


class ShardJob:
    def __init__(self) -> None:
        self.serial: list[ShardedTest] = []
        self.parallel: list[ShardedTest] = []

    def get_total_time(self) -> float:
        """Estimate this job's total runtime: parallel tests packed greedily
        across NUM_PROCS_FOR_SHARDING_CALC procs, plus all serial tests."""
        procs = [0.0 for _ in range(NUM_PROCS_FOR_SHARDING_CALC)]
        for test in self.parallel:
            min_index = procs.index(min(procs))
            procs[min_index] += test.get_time()
        time = max(procs) + sum(test.get_time() for test in self.serial)
        return time

    def convert_to_tuple(self) -> tuple[float, list[ShardedTest]]:
        return (self.get_total_time(), self.serial + self.parallel)


def get_with_pytest_shard(
    tests: Sequence[TestRun],
    test_file_times: dict[str, float],
    test_class_times: dict[str, dict[str, float]] | None,
) -> list[ShardedTest]:
    sharded_tests: list[ShardedTest] = []

    for test in tests:
        duration = get_duration(test, test_file_times, test_class_times or {})

        if duration and duration > THRESHOLD:
            num_shards = math.ceil(duration / THRESHOLD)
            for i in range(num_shards):
                sharded_tests.append(
                    ShardedTest(test, i + 1, num_shards, duration / num_shards)
                )
        else:
            sharded_tests.append(ShardedTest(test, 1, 1, duration))
    return sharded_tests
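
# Worked example for get_with_pytest_shard above (illustration only, not part
# of the original comments): with THRESHOLD = 600 seconds, a test file whose
# estimated duration is 1500s becomes math.ceil(1500 / 600) == 3 pytest shards,
# ShardedTest(test, 1, 3, 500.0) through ShardedTest(test, 3, 3, 500.0), while
# a 300s file stays as a single ShardedTest(test, 1, 1, 300.0).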


def get_duration(
    test: TestRun,
    test_file_times: dict[str, float],
    test_class_times: dict[str, dict[str, float]],
) -> float | None:
    """Calculate the time for a TestRun based on the given test_file_times and
    test_class_times.  Returns None if the time is unknown."""
    file_duration = test_file_times.get(test.test_file, None)
    if test.is_full_file():
        return file_duration

    def get_duration_for_classes(
        test_file: str, test_classes: frozenset[str]
    ) -> float | None:
        duration: float = 0

        for test_class in test_classes:
            class_duration = test_class_times.get(test_file, {}).get(test_class, None)
            if class_duration is None:
                return None
            duration += class_duration
        return duration

    included = test.included()
    excluded = test.excluded()
    included_classes_duration = get_duration_for_classes(test.test_file, included)
    excluded_classes_duration = get_duration_for_classes(test.test_file, excluded)

    if included_classes_duration is None or excluded_classes_duration is None:
        # Didn't get the time for all classes, so time is unknown
        return None

    if included:
        return included_classes_duration
    assert (
        excluded
    ), f"TestRun {test} is not full file but doesn't have included or excluded classes"
    if file_duration is None:
        return None
    return file_duration - excluded_classes_duration


def shard(
    sharded_jobs: list[ShardJob],
    pytest_sharded_tests: Sequence[ShardedTest],
    estimated_time_limit: float | None = None,
    serial: bool = False,
) -> None:
    # Modifies sharded_jobs in place
    if len(sharded_jobs) == 0:
        assert (
            len(pytest_sharded_tests) == 0
        ), "No shards provided but there are tests to shard"
        return

    round_robin_index = 0

    def _get_min_sharded_job(
        sharded_jobs: list[ShardJob], test: ShardedTest
    ) -> ShardJob:
        if test.time is None:
            nonlocal round_robin_index
            job = sharded_jobs[round_robin_index % len(sharded_jobs)]
            round_robin_index += 1
            return job
        return min(sharded_jobs, key=lambda j: j.get_total_time())

    def _shard_serial(
        tests: Sequence[ShardedTest], sharded_jobs: list[ShardJob]
    ) -> None:
        assert estimated_time_limit is not None, "Estimated time limit must be provided"
        new_sharded_jobs = sharded_jobs
        for test in tests:
            if (
                len(sharded_jobs) > 1
                and sharded_jobs[-1].get_total_time() > estimated_time_limit
            ):
                new_sharded_jobs = sharded_jobs[:-1]
            min_sharded_job = _get_min_sharded_job(new_sharded_jobs, test)
            min_sharded_job.serial.append(test)

    def _shard_parallel(
        tests: Sequence[ShardedTest], sharded_jobs: list[ShardJob]
    ) -> None:
        for test in tests:
            min_sharded_job = _get_min_sharded_job(sharded_jobs, test)
            min_sharded_job.parallel.append(test)

    if serial:
        _shard_serial(pytest_sharded_tests, sharded_jobs)
    else:
        _shard_parallel(pytest_sharded_tests, sharded_jobs)

    return
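
# A sketch of how shard() above distributes work (illustration only): with two
# empty ShardJobs and parallel ShardedTests of 30s, 20s, and 10s, the greedy
# _get_min_sharded_job puts 30s on job 1, 20s on job 2, then 10s on job 2,
# leaving both jobs at 30s; a test with an unknown time is instead assigned
# round robin, ignoring the current load of each job.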


def calculate_shards(
    num_shards: int,
    tests: Sequence[TestRun],
    test_file_times: dict[str, float],
    test_class_times: dict[str, dict[str, float]] | None,
    must_serial: Callable[[str], bool] | None = None,
    sort_by_time: bool = True,
) -> list[tuple[float, list[ShardedTest]]]:
    must_serial = must_serial or (lambda x: True)
    test_class_times = test_class_times or {}

    # Divide tests into pytest shards
    if sort_by_time:
        known_tests = [
            x
            for x in tests
            if get_duration(x, test_file_times, test_class_times) is not None
        ]
        unknown_tests = [x for x in tests if x not in known_tests]

        pytest_sharded_tests = sorted(
            get_with_pytest_shard(known_tests, test_file_times, test_class_times),
            key=lambda j: j.get_time(),
            reverse=True,
        ) + get_with_pytest_shard(unknown_tests, test_file_times, test_class_times)
    else:
        pytest_sharded_tests = get_with_pytest_shard(
            tests, test_file_times, test_class_times
        )
    del tests

    serial_tests = [test for test in pytest_sharded_tests if must_serial(test.name)]
    parallel_tests = [test for test in pytest_sharded_tests if test not in serial_tests]

    serial_time = sum(test.get_time() for test in serial_tests)
    parallel_time = sum(test.get_time() for test in parallel_tests)
    total_time = serial_time + parallel_time / NUM_PROCS_FOR_SHARDING_CALC
    estimated_time_per_shard = total_time / num_shards
    # Separate serial tests from parallel tests as much as possible to maximize
    # parallelism by putting all the serial tests on the first num_serial_shards
    # shards. The estimated_time_limit is the estimated time it should take for
    # the least filled serial shard. Ex: if we have 8 min of serial tests, 20 min
    # of parallel tests, 6 shards, and 2 procs per machine, we would expect each
    # machine to take 3 min and should aim for 3 serial shards, with shards 1
    # and 2 taking 3 min and shard 3 taking 2 min. The estimated time limit
    # would be 2 min. This ensures that the first few shards contain as many
    # serial tests as possible and as few parallel tests as possible. The least
    # filled/last (in the example, the 3rd) shard may contain a lot of both
    # serial and parallel tests.
    estimated_time_limit = 0.0
    if estimated_time_per_shard != 0:
        estimated_time_limit = serial_time % estimated_time_per_shard
    if estimated_time_limit <= 0.01:
        estimated_time_limit = estimated_time_per_shard
    if total_time == 0:
        num_serial_shards = num_shards
    else:
        num_serial_shards = max(math.ceil(serial_time / total_time * num_shards), 1)

    sharded_jobs = [ShardJob() for _ in range(num_shards)]
    shard(
        sharded_jobs=sharded_jobs[:num_serial_shards],
        pytest_sharded_tests=serial_tests,
        estimated_time_limit=estimated_time_limit,
        serial=True,
    )
    shard(
        sharded_jobs=sharded_jobs,
        pytest_sharded_tests=parallel_tests,
        serial=False,
    )

    return [job.convert_to_tuple() for job in sharded_jobs]


def get_test_case_configs(dirpath: str) -> None:
    get_disabled_tests(dirpath=dirpath)
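

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module. It assumes
    # the TestRun constructor accepts a test-file name as its first positional
    # argument (consistent with how .test_file and .name are used above) and
    # that the script is run from the repo root so the tools package imports.
    example_tests = [TestRun("test_foo"), TestRun("test_bar"), TestRun("test_baz")]
    example_file_times = {"test_foo": 1200.0, "test_bar": 300.0}  # test_baz unknown
    example_shards = calculate_shards(
        num_shards=2,
        tests=example_tests,
        test_file_times=example_file_times,
        test_class_times=None,
        must_serial=lambda name: name == "test_bar",
        sort_by_time=True,
    )
    for total_time, sharded_tests in example_shards:
        print(f"{total_time:.1f}s: {[t.name for t in sharded_tests]}")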