xref: /aosp_15_r20/external/pytorch/test/run_test.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker#!/usr/bin/env python3
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport argparse
4*da0073e9SAndroid Build Coastguard Workerimport copy
5*da0073e9SAndroid Build Coastguard Workerimport glob
6*da0073e9SAndroid Build Coastguard Workerimport json
7*da0073e9SAndroid Build Coastguard Workerimport os
8*da0073e9SAndroid Build Coastguard Workerimport re
9*da0073e9SAndroid Build Coastguard Workerimport shutil
10*da0073e9SAndroid Build Coastguard Workerimport signal
11*da0073e9SAndroid Build Coastguard Workerimport subprocess
12*da0073e9SAndroid Build Coastguard Workerimport sys
13*da0073e9SAndroid Build Coastguard Workerimport tempfile
14*da0073e9SAndroid Build Coastguard Workerimport time
15*da0073e9SAndroid Build Coastguard Workerfrom collections import defaultdict
16*da0073e9SAndroid Build Coastguard Workerfrom contextlib import ExitStack
17*da0073e9SAndroid Build Coastguard Workerfrom datetime import datetime
18*da0073e9SAndroid Build Coastguard Workerfrom pathlib import Path
19*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, cast, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
20*da0073e9SAndroid Build Coastguard Worker
21*da0073e9SAndroid Build Coastguard Workerimport pkg_resources
22*da0073e9SAndroid Build Coastguard Worker
23*da0073e9SAndroid Build Coastguard Workerimport torch
24*da0073e9SAndroid Build Coastguard Workerimport torch.distributed as dist
25*da0073e9SAndroid Build Coastguard Workerfrom torch.multiprocessing import current_process, get_context
26*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (
27*da0073e9SAndroid Build Coastguard Worker    get_report_path,
28*da0073e9SAndroid Build Coastguard Worker    IS_CI,
29*da0073e9SAndroid Build Coastguard Worker    IS_MACOS,
30*da0073e9SAndroid Build Coastguard Worker    IS_WINDOWS,
31*da0073e9SAndroid Build Coastguard Worker    retry_shell,
32*da0073e9SAndroid Build Coastguard Worker    set_cwd,
33*da0073e9SAndroid Build Coastguard Worker    shell,
34*da0073e9SAndroid Build Coastguard Worker    TEST_CUDA,
35*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_ASAN,
36*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_CROSSREF,
37*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_ROCM,
38*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_SLOW_GRADCHECK,
39*da0073e9SAndroid Build Coastguard Worker)
40*da0073e9SAndroid Build Coastguard Worker
41*da0073e9SAndroid Build Coastguard Worker
42*da0073e9SAndroid Build Coastguard Worker# using tools/ to optimize test run.
43*da0073e9SAndroid Build Coastguard WorkerREPO_ROOT = Path(__file__).resolve().parent.parent
44*da0073e9SAndroid Build Coastguard Workersys.path.insert(0, str(REPO_ROOT))
45*da0073e9SAndroid Build Coastguard Worker
46*da0073e9SAndroid Build Coastguard Workerfrom tools.stats.import_test_stats import (
47*da0073e9SAndroid Build Coastguard Worker    ADDITIONAL_CI_FILES_FOLDER,
48*da0073e9SAndroid Build Coastguard Worker    TEST_CLASS_TIMES_FILE,
49*da0073e9SAndroid Build Coastguard Worker    TEST_TIMES_FILE,
50*da0073e9SAndroid Build Coastguard Worker)
51*da0073e9SAndroid Build Coastguard Workerfrom tools.stats.upload_metrics import add_global_metric, emit_metric
52*da0073e9SAndroid Build Coastguard Workerfrom tools.testing.discover_tests import (
53*da0073e9SAndroid Build Coastguard Worker    CPP_TEST_PATH,
54*da0073e9SAndroid Build Coastguard Worker    CPP_TEST_PREFIX,
55*da0073e9SAndroid Build Coastguard Worker    CPP_TESTS_DIR,
56*da0073e9SAndroid Build Coastguard Worker    parse_test_module,
57*da0073e9SAndroid Build Coastguard Worker    TESTS,
58*da0073e9SAndroid Build Coastguard Worker)
59*da0073e9SAndroid Build Coastguard Workerfrom tools.testing.do_target_determination_for_s3 import import_results
60*da0073e9SAndroid Build Coastguard Workerfrom tools.testing.target_determination.gen_artifact import gen_ci_artifact
61*da0073e9SAndroid Build Coastguard Workerfrom tools.testing.target_determination.heuristics.previously_failed_in_pr import (
62*da0073e9SAndroid Build Coastguard Worker    gen_additional_test_failures_file,
63*da0073e9SAndroid Build Coastguard Worker)
64*da0073e9SAndroid Build Coastguard Workerfrom tools.testing.target_determination.heuristics.utils import get_pr_number
65*da0073e9SAndroid Build Coastguard Workerfrom tools.testing.test_run import TestRun
66*da0073e9SAndroid Build Coastguard Workerfrom tools.testing.test_selections import (
67*da0073e9SAndroid Build Coastguard Worker    calculate_shards,
68*da0073e9SAndroid Build Coastguard Worker    get_test_case_configs,
69*da0073e9SAndroid Build Coastguard Worker    NUM_PROCS,
70*da0073e9SAndroid Build Coastguard Worker    ShardedTest,
71*da0073e9SAndroid Build Coastguard Worker    THRESHOLD,
72*da0073e9SAndroid Build Coastguard Worker)
73*da0073e9SAndroid Build Coastguard Worker
74*da0073e9SAndroid Build Coastguard Worker
75*da0073e9SAndroid Build Coastguard Worker# Make sure to remove REPO_ROOT after import is done
76*da0073e9SAndroid Build Coastguard Workersys.path.remove(str(REPO_ROOT))
77*da0073e9SAndroid Build Coastguard Worker
78*da0073e9SAndroid Build Coastguard Worker
# Always True in this file: it is assigned unconditionally.
HAVE_TEST_SELECTION_TOOLS = True
# CI configuration read from the environment; both default to "" when unset.
TEST_CONFIG = os.getenv("TEST_CONFIG", "")
BUILD_ENVIRONMENT = os.getenv("BUILD_ENVIRONMENT", "")
# True only when PYTORCH_TEST_RERUN_DISABLED_TESTS is exactly the string "1".
RERUN_DISABLED_TESTS = os.getenv("PYTORCH_TEST_RERUN_DISABLED_TESTS", "0") == "1"
# Path prefixes used to recognize distributed and inductor test files.
DISTRIBUTED_TEST_PREFIX = "distributed"
INDUCTOR_TEST_PREFIX = "inductor"
# Substring match: any TEST_CONFIG or BUILD_ENVIRONMENT containing "slow" counts.
IS_SLOW = "slow" in TEST_CONFIG or "slow" in BUILD_ENVIRONMENT
86*da0073e9SAndroid Build Coastguard Worker
87*da0073e9SAndroid Build Coastguard Worker
88*da0073e9SAndroid Build Coastguard Worker# Note [ROCm parallel CI testing]
89*da0073e9SAndroid Build Coastguard Worker# https://github.com/pytorch/pytorch/pull/85770 added file-granularity parallel testing.
90*da0073e9SAndroid Build Coastguard Worker# In .ci/pytorch/test.sh, TEST_CONFIG == "default", CUDA and HIP_VISIBLE_DEVICES is set to 0.
91*da0073e9SAndroid Build Coastguard Worker# This results in multiple test files sharing the same GPU.
92*da0073e9SAndroid Build Coastguard Worker# This should be a supported use case for ROCm, but it exposed issues in the kernel driver resulting in hangs.
93*da0073e9SAndroid Build Coastguard Worker# See https://github.com/pytorch/pytorch/issues/90940.
94*da0073e9SAndroid Build Coastguard Worker#
95*da0073e9SAndroid Build Coastguard Worker# Further, ROCm self-hosted runners have up to 4 GPUs.
96*da0073e9SAndroid Build Coastguard Worker# Device visibility was set to 0 to match CUDA test behavior, but this was wasting available GPU resources.
97*da0073e9SAndroid Build Coastguard Worker# Assigning each Pool worker their own dedicated GPU avoids the ROCm oversubscription issues.
98*da0073e9SAndroid Build Coastguard Worker# This should also result in better overall wall clock time since all GPUs can be utilized.
def maybe_set_hip_visible_devies():
    """Give a ROCm pool worker its own GPU through HIP_VISIBLE_DEVICES.

    Nothing happens unless ``torch.version.hip`` is truthy and the current
    process is not named "MainProcess". In that case the first element of the
    worker's identity, modulo NUM_PROCS, is written to
    ``os.environ["HIP_VISIBLE_DEVICES"]``.
    See Note [ROCm parallel CI testing].
    """
    # Special handling applies only to ROCm (file-granularity parallel tests).
    if not torch.version.hip:
        return

    worker = current_process()
    if worker.name == "MainProcess":
        # The main process is left with its existing device visibility.
        return

    # This is a Process from a parallel Pool: derive its device index.
    device_index = worker._identity[0] % NUM_PROCS
    os.environ["HIP_VISIBLE_DEVICES"] = str(device_index)
106*da0073e9SAndroid Build Coastguard Worker
107*da0073e9SAndroid Build Coastguard Worker
def strtobool(s):
    """Interpret a string as a boolean flag.

    Returns False when the lowercased string is "", "0", "false" or "off";
    any other string yields True.
    """
    falsy_values = ("", "0", "false", "off")
    normalized = s.lower()
    return normalized not in falsy_values
110*da0073e9SAndroid Build Coastguard Worker
111*da0073e9SAndroid Build Coastguard Worker
class TestChoices(list):
    """List of test names whose membership check normalizes the candidate.

    ``item in choices`` first passes ``item`` through ``parse_test_module``
    and then performs the ordinary list lookup on the result.
    """

    def __init__(self, *args, **kwargs):
        # Only the first positional argument seeds the list; any further
        # positional or keyword arguments are accepted and ignored.
        initial_items = args[0]
        super().__init__(initial_items)

    def __contains__(self, item):
        module_name = parse_test_module(item)
        return list.__contains__(self, module_name)
118*da0073e9SAndroid Build Coastguard Worker
119*da0073e9SAndroid Build Coastguard Worker
# Every entry of TESTS whose path starts with "distributed/fsdp".
FSDP_TEST = [test for test in TESTS if test.startswith("distributed/fsdp")]
121*da0073e9SAndroid Build Coastguard Worker
# Test modules blocklisted for Windows (the code that consumes this list is
# outside this section of the file). All FSDP tests are appended at the end.
#
# NOTE: every entry must end with a comma. Adjacent string literals without a
# comma are silently concatenated into one bogus name that matches no test.
WINDOWS_BLOCKLIST = [
    "distributed/nn/jit/test_instantiator",
    "distributed/rpc/test_faulty_agent",
    "distributed/rpc/test_tensorpipe_agent",
    "distributed/rpc/test_share_memory",
    "distributed/rpc/cuda/test_tensorpipe_agent",
    "distributed/pipeline/sync/skip/test_api",
    "distributed/pipeline/sync/skip/test_gpipe",
    "distributed/pipeline/sync/skip/test_inspect_skip_layout",
    "distributed/pipeline/sync/skip/test_leak",
    "distributed/pipeline/sync/skip/test_portal",
    "distributed/pipeline/sync/skip/test_stash_pop",
    "distributed/pipeline/sync/skip/test_tracker",
    "distributed/pipeline/sync/skip/test_verify_skippables",
    "distributed/pipeline/sync/test_balance",
    "distributed/pipeline/sync/test_bugs",
    "distributed/pipeline/sync/test_checkpoint",
    "distributed/pipeline/sync/test_copy",
    "distributed/pipeline/sync/test_deferred_batch_norm",
    "distributed/pipeline/sync/test_dependency",
    "distributed/pipeline/sync/test_inplace",
    "distributed/pipeline/sync/test_microbatch",
    "distributed/pipeline/sync/test_phony",
    "distributed/pipeline/sync/test_pipe",
    "distributed/pipeline/sync/test_pipeline",
    "distributed/pipeline/sync/test_stream",
    "distributed/pipeline/sync/test_transparency",
    "distributed/pipeline/sync/test_worker",
    "distributed/elastic/agent/server/test/api_test",
    "distributed/elastic/multiprocessing/api_test",
    # The next two entries previously lacked trailing commas and were fused
    # with the sharding_spec entry into a single non-matching string.
    "distributed/_shard/checkpoint/test_checkpoint",
    "distributed/_shard/checkpoint/test_file_system_checkpoint",
    "distributed/_shard/sharding_spec/test_sharding_spec",
    "distributed/_shard/sharding_plan/test_sharding_plan",
    "distributed/_shard/sharded_tensor/test_sharded_tensor",
    "distributed/_shard/sharded_tensor/test_sharded_tensor_reshard",
    "distributed/_shard/sharded_tensor/ops/test_embedding",
    "distributed/_shard/sharded_tensor/ops/test_embedding_bag",
    "distributed/_shard/sharded_tensor/ops/test_binary_cmp",
    "distributed/_shard/sharded_tensor/ops/test_init",
    "distributed/_shard/sharded_optim/test_sharded_optim",
] + FSDP_TEST
164*da0073e9SAndroid Build Coastguard Worker
# Test modules blocklisted for ROCm (the code that consumes this list is
# outside this section of the file).
#
# NOTE: every entry must end with a comma. Adjacent string literals without a
# comma are silently concatenated into one bogus name that matches no test.
ROCM_BLOCKLIST = [
    "distributed/rpc/test_faulty_agent",
    "distributed/rpc/test_tensorpipe_agent",
    "distributed/rpc/test_share_memory",
    "distributed/rpc/cuda/test_tensorpipe_agent",
    # The next two entries previously lacked trailing commas and were fused
    # with the sharding_spec entry into a single non-matching string.
    "distributed/_shard/checkpoint/test_checkpoint",
    "distributed/_shard/checkpoint/test_file_system_checkpoint",
    "distributed/_shard/sharding_spec/test_sharding_spec",
    "distributed/_shard/sharding_plan/test_sharding_plan",
    "distributed/_shard/sharded_tensor/test_sharded_tensor",
    "distributed/_shard/sharded_tensor/test_sharded_tensor_reshard",
    "distributed/_shard/sharded_tensor/ops/test_embedding",
    "distributed/_shard/sharded_tensor/ops/test_embedding_bag",
    "distributed/_shard/sharded_tensor/ops/test_binary_cmp",
    "distributed/_shard/sharded_tensor/ops/test_init",
    "distributed/_shard/sharded_optim/test_sharded_optim",
    "test_determination",
    "test_jit_legacy",
    "test_cuda_nvml_based_avail",
    "test_jit_cuda_fuser",
    "distributed/_tensor/test_attention",
]
187*da0073e9SAndroid Build Coastguard Worker
# Test modules blocklisted for XPU (the consumer of this list is outside this
# section of the file).
XPU_BLOCKLIST = [
    "test_autograd",
    "profiler/test_cpp_thread",
    "profiler/test_execution_trace",
    "profiler/test_memory_profiler",
    "profiler/test_profiler",
    "profiler/test_profiler_tree",
    "profiler/test_record_function",
    "profiler/test_torch_tidy",
]

# XPU-specific test modules.
XPU_TEST = [
    "test_xpu",
]
202*da0073e9SAndroid Build Coastguard Worker
# The tests inside these files should never be run in parallel with each other.
# Entries are matched by exact test-file name; all FSDP tests are appended.
RUN_PARALLEL_BLOCKLIST = [
    "test_cpp_extensions_jit",
    "test_cpp_extensions_open_device_registration",
    "test_cpp_extensions_stream_and_event",
    "test_cpp_extensions_mtia_backend",
    "test_jit_disabled",
    "test_mobile_optimizer",
    "test_multiprocessing",
    "test_multiprocessing_spawn",
    "test_namedtuple_return_api",
    "test_overrides",
    "test_show_pickle",
    "test_tensorexpr",
    "test_cuda_primary_ctx",
    "test_cuda_trace",
    "inductor/test_benchmark_fusion",
    "test_cuda_nvml_based_avail",
    # temporarily sets a global config
    "test_autograd_fallback",
] + FSDP_TEST
224*da0073e9SAndroid Build Coastguard Worker
# Test files that should always be run serially with other test files,
# but it's okay if the tests inside them are run in parallel with each other.
# Inline comments record the reason where one is known.
CI_SERIAL_LIST = [
    "test_nn",
    "test_fake_tensor",
    "test_cpp_api_parity",
    "test_reductions",
    "test_fx_backends",
    "test_cpp_extensions_jit",
    "test_torch",
    "test_tensor_creation_ops",
    "test_dispatch",
    "test_python_dispatch",  # torch.library creation and deletion must be serialized
    "test_spectral_ops",  # Cause CUDA illegal memory access https://github.com/pytorch/pytorch/issues/88916
    "nn/test_pooling",
    "nn/test_convolution",  # Doesn't respect set_per_process_memory_fraction, results in OOM for other tests in slow gradcheck
    "distributions/test_distributions",
    "test_fx",  # gets SIGKILL
    "functorch/test_memory_efficient_fusion",  # Cause CUDA OOM on ROCm
    "test_utils",  # OOM
    "test_sort_and_select",  # OOM
    "test_backward_compatible_arguments",  # OOM
    "test_autocast",  # OOM
    "test_native_mha",  # OOM
    "test_module_hooks",  # OOM
    "inductor/test_max_autotune",
    "inductor/test_cutlass_backend",  # slow due to many nvcc compilation steps
    "inductor/test_flex_attention",  # OOM
]
# A subset of onnx tests that cannot run in parallel due to high memory usage.
ONNX_SERIAL_LIST = [
    "onnx/test_models",
    "onnx/test_models_quantized_onnxruntime",
    "onnx/test_models_onnxruntime",
    "onnx/test_custom_ops",
    "onnx/test_utility_funs",
]
262*da0073e9SAndroid Build Coastguard Worker
# A subset of our TEST list that validates that PyTorch's ops, modules, and
# autograd function as expected.
CORE_TEST_LIST = [
    "test_autograd",
    "test_autograd_fallback",
    "test_modules",
    "test_nn",
    "test_ops",
    "test_ops_gradients",
    "test_ops_fwd_gradients",
    "test_ops_jit",
    "test_torch",
]
275*da0073e9SAndroid Build Coastguard Worker
276*da0073e9SAndroid Build Coastguard Worker
# Threshold in seconds (300 s = 5 min): if a test file takes longer than this,
# we add it to TARGET_DET_LIST.
SLOW_TEST_THRESHOLD = 300
279*da0073e9SAndroid Build Coastguard Worker
# Maps a distributed backend name to the environment variables (all string
# values) configured for that backend. Stays empty when torch.distributed is
# unavailable.
DISTRIBUTED_TESTS_CONFIG: Dict[str, Dict[str, str]] = {}


if dist.is_available():
    DISTRIBUTED_TESTS_CONFIG["test"] = {"WORLD_SIZE": "1"}
    # MPI is only configured on non-ROCm builds.
    if not TEST_WITH_ROCM and dist.is_mpi_available():
        DISTRIBUTED_TESTS_CONFIG["mpi"] = {
            "WORLD_SIZE": "3",
            "TEST_REPORT_SOURCE_OVERRIDE": "dist-mpi",
        }
    # nccl, gloo and ucc use a world size of "2" only when
    # torch.cuda.device_count() reports exactly 2 devices, and "3" otherwise.
    if dist.is_nccl_available():
        DISTRIBUTED_TESTS_CONFIG["nccl"] = {
            "WORLD_SIZE": "2" if torch.cuda.device_count() == 2 else "3",
            "TEST_REPORT_SOURCE_OVERRIDE": "dist-nccl",
        }
    if dist.is_gloo_available():
        DISTRIBUTED_TESTS_CONFIG["gloo"] = {
            "WORLD_SIZE": "2" if torch.cuda.device_count() == 2 else "3",
            "TEST_REPORT_SOURCE_OVERRIDE": "dist-gloo",
        }
    if dist.is_ucc_available():
        DISTRIBUTED_TESTS_CONFIG["ucc"] = {
            "WORLD_SIZE": "2" if torch.cuda.device_count() == 2 else "3",
            "TEST_REPORT_SOURCE_OVERRIDE": "dist-ucc",
            "UCX_TLS": "tcp,cuda",
            "UCC_TLS": "nccl,ucp,cuda",
            "UCC_TL_UCP_TUNE": "cuda:0",  # don't use UCP TL on CUDA as it is not well supported
            "UCC_EC_CUDA_USE_COOPERATIVE_LAUNCH": "n",  # CI nodes (M60) fail if it is on
        }
309*da0073e9SAndroid Build Coastguard Worker
# Reverse lookup from a signal value to its name, built from the signal module.
# https://stackoverflow.com/questions/2549939/get-signal-names-from-numbers-in-python
SIGNALS_TO_NAMES_DICT = dict(
    (getattr(signal, sig_name), sig_name)
    for sig_name in dir(signal)
    # Keep names such as SIGINT; names containing "_" (e.g. SIG_DFL) are skipped.
    if sig_name.startswith("SIG") and "_" not in sig_name
)
314*da0073e9SAndroid Build Coastguard Worker
# Message text explaining that ninja is missing and how to install it or
# exclude the affected C++ extension tests.
CPP_EXTENSIONS_ERROR = """
Ninja (https://ninja-build.org) is required for some of the C++ extensions
tests, but it could not be found. Install ninja with `pip install ninja`
or `conda install ninja`. Alternatively, disable said tests with
`run_test.py --exclude test_cpp_extensions_aot_ninja test_cpp_extensions_jit`.
"""

# True when PYTORCH_COLLECT_COVERAGE is set to any non-empty string
# (including "0"); False when it is unset or empty.
PYTORCH_COLLECT_COVERAGE = bool(os.environ.get("PYTORCH_COLLECT_COVERAGE"))
323*da0073e9SAndroid Build Coastguard Worker
# Test modules grouped under the JIT executor label.
JIT_EXECUTOR_TESTS = [
    "test_jit_profiling",
    "test_jit_legacy",
    "test_jit_fuser_legacy",
]

# Groups of discovered tests, selected purely by the prefix of the test path.
INDUCTOR_TESTS = [test for test in TESTS if test.startswith(INDUCTOR_TEST_PREFIX)]
DISTRIBUTED_TESTS = [test for test in TESTS if test.startswith(DISTRIBUTED_TEST_PREFIX)]
TORCH_EXPORT_TESTS = [test for test in TESTS if test.startswith("export")]
FUNCTORCH_TESTS = [test for test in TESTS if test.startswith("functorch")]
ONNX_TESTS = [test for test in TESTS if test.startswith("onnx")]
CPP_TESTS = [test for test in TESTS if test.startswith(CPP_TEST_PREFIX)]

# Test modules that need LAPACK support.
TESTS_REQUIRING_LAPACK = [
    "distributions/test_constraints",
    "distributions/test_distributions",
]
341*da0073e9SAndroid Build Coastguard Worker
# Test files that do not use gradcheck.
# These are just the slowest ones; this isn't an exhaustive list.
TESTS_NOT_USING_GRADCHECK = [
    # Note that you should use skipIfSlowGradcheckEnv if you do not wish to
    # skip all the tests in that file, e.g. test_mps
    "doctests",
    "test_meta",
    "test_hub",
    "test_fx",
    "test_decomp",
    "test_cpp_extensions_jit",
    "test_jit",
    "test_ops",
    "test_ops_jit",
    "dynamo/test_recompile_ux",
    "inductor/test_smoke",
    "test_quantization",
]
359*da0073e9SAndroid Build Coastguard Worker
360*da0073e9SAndroid Build Coastguard Worker
def print_to_stderr(message):
    """Print ``message`` to standard error, followed by a newline."""
    error_stream = sys.stderr
    print(message, file=error_stream)
363*da0073e9SAndroid Build Coastguard Worker
364*da0073e9SAndroid Build Coastguard Worker
def get_executable_command(options, disable_coverage=False, is_cpp_test=False):
    """Return the command prefix (a list of strings) used to launch a test.

    Args:
        options: object exposing a ``coverage`` attribute.
        disable_coverage: when true, ignore ``options.coverage``.
        is_cpp_test: whether the command is for a C++ test.

    Returns:
        - Python test with coverage: the ``coverage run`` invocation.
        - Python test without coverage: ``[sys.executable, "-bb"]``.
        - C++ test without coverage: ``["pytest"]``.
        - C++ test with coverage: an empty list (not supported yet).
    """
    with_coverage = options.coverage and not disable_coverage

    if is_cpp_test:
        # TODO: C++ with coverage is not yet supported
        return [] if with_coverage else ["pytest"]

    if with_coverage:
        return ["coverage", "run", "--parallel-mode", "--source=torch"]

    return [sys.executable, "-bb"]
379*da0073e9SAndroid Build Coastguard Worker
380*da0073e9SAndroid Build Coastguard Worker
def run_test(
    test_module: ShardedTest,
    test_directory,
    options,
    launcher_cmd=None,
    extra_unittest_args=None,
    env=None,
    print_log=True,
) -> int:
    """Build the command line for one test shard, execute it, and return its exit code.

    Args:
        test_module: The shard to run; ``name``, ``shard``, ``num_shards``,
            ``time`` and ``get_pytest_args()`` are read from it.
        test_directory: Working directory handed to the shell helpers; its
            parent is also used to locate C++ test binaries when
            ``CPP_TESTS_DIR`` is falsy.
        options: Parsed command-line options (``additional_args``, ``verbose``,
            ``pytest``, ``showlocals``, ``pipe_logs``, ``enable_timeout``,
            ``continue_through_error`` are consulted here).
        launcher_cmd: Optional list of arguments placed in front of the
            executable.
        extra_unittest_args: Optional list of extra arguments for the test.
        env: Environment for the child process. A falsy value means "use a copy
            of ``os.environ``".
        print_log: When logs are piped to a file, whether to hand that file to
            ``handle_log_file`` once the run is over.

    Returns:
        ``0`` when the run is skipped (C++ test under RERUN_DISABLED_TESTS, or
        no executable available); otherwise the exit code of the run, with the
        "no tests collected" code 5 reported as ``0``.
    """
    print_to_stderr(
        "SCRIBE_GRAPHQL_ACCESS_TOKEN is set"
        if os.getenv("SCRIBE_GRAPHQL_ACCESS_TOKEN", "")
        else "SCRIBE_GRAPHQL_ACCESS_TOKEN is NOT set"
    )

    # NB: the environment snapshot is taken *before* the HIP helper runs.
    env = env or os.environ.copy()
    maybe_set_hip_visible_devies()
    unittest_args = options.additional_args.copy()
    test_file = test_module.name

    is_distributed_test = test_file.startswith(DISTRIBUTED_TEST_PREFIX)
    is_cpp_test = test_file.startswith(CPP_TEST_PREFIX)

    if is_cpp_test:
        # NB: Rerun disabled tests depends on pytest-flakefinder and it doesn't
        # work with pytest-cpp atm. We also don't have support to disable C++
        # test yet, so it's ok to just return successfully here
        if RERUN_DISABLED_TESTS:
            print_to_stderr(
                "Skipping C++ tests when running under RERUN_DISABLED_TESTS mode"
            )
            return 0
        stepcurrent_key = f"{test_file}_{os.urandom(8).hex()}"
    else:
        # Sharding flags are only passed to non-C++ tests
        unittest_args += [
            f"--shard-id={test_module.shard}",
            f"--num-shards={test_module.num_shards}",
        ]
        stepcurrent_key = f"{test_file}_{test_module.shard}_{os.urandom(8).hex()}"

    if options.verbose:
        # in case of pytest
        unittest_args.append("-" + "v" * options.verbose)

    if test_file in RUN_PARALLEL_BLOCKLIST:
        unittest_args = [
            arg for arg in unittest_args if not arg.startswith("--run-parallel")
        ]

    if extra_unittest_args:
        assert isinstance(extra_unittest_args, list)
        unittest_args.extend(extra_unittest_args)

    if options.pytest:
        unittest_args.extend(
            get_pytest_args(
                options,
                is_cpp_test=is_cpp_test,
                is_distributed_test=is_distributed_test,
            )
        )
        unittest_args.extend(test_module.get_pytest_args())
        # If using pytest, replace -f with equivalent -x
        unittest_args = ["-x" if arg == "-f" else arg for arg in unittest_args]

    if options.showlocals:
        unittest_args.extend(
            ["--showlocals", "--tb=long", "--color=yes"]
            if options.pytest
            else ["--locals"]
        )

    # NB: These features are not available for C++ tests, but there is little
    # incentive to implement it because we have never seen a flaky C++ test before.
    if IS_CI and not is_cpp_test:
        # use the downloaded test cases configuration, not supported in pytest
        unittest_args.extend(["--import-slow-tests", "--import-disabled-tests"])
        if RERUN_DISABLED_TESTS:
            unittest_args.append("--rerun-disabled-tests")

    if test_file in PYTEST_SKIP_RETRIES:
        if not options.pytest:
            raise RuntimeError(
                "A test running without pytest cannot skip retries using "
                "the PYTEST_SKIP_RETRIES set."
            )
        unittest_args = [arg for arg in unittest_args if "--reruns" not in arg]

    # Extra arguments are not supported with pytest
    executable = get_executable_command(options, is_cpp_test=is_cpp_test)
    if not executable:
        # If there is no eligible executable returning here, it means an unsupported
        # case such as coverage for C++ test. So just returning ok makes sense
        return 0

    if is_cpp_test:
        # C++ tests are not the regular test directory
        relative_binary = test_file.replace(f"{CPP_TEST_PREFIX}/", "")
        if CPP_TESTS_DIR:
            cpp_test = os.path.join(CPP_TESTS_DIR, relative_binary)
        else:
            cpp_test = os.path.join(
                Path(test_directory).parent, CPP_TEST_PATH, relative_binary
            )
        if sys.platform == "win32":
            cpp_test = cpp_test + ".exe"
        argv = [cpp_test] + unittest_args
    else:
        # Can't call `python -m unittest test_*` here because it doesn't run code
        # in `if __name__ == '__main__': `. So call `python test_*.py` instead.
        argv = [test_file + ".py"] + unittest_args

    reports_dir = REPO_ROOT / "test" / "test-reports"
    os.makedirs(reports_dir, exist_ok=True)
    if options.pipe_logs:
        # Only the path is needed here; the file is reopened for writing below.
        log_fd, log_path = tempfile.mkstemp(
            dir=reports_dir,
            prefix=f"{sanitize_file_name(str(test_module))}_",
            suffix="_toprint.log",
        )
        os.close(log_fd)

    command = (launcher_cmd or []) + executable + argv
    should_retry = not (
        "--subprocess" in command
        or RERUN_DISABLED_TESTS
        or is_cpp_test
        or "-n" in command
    )

    if not options.enable_timeout:
        timeout = None
    elif IS_SLOW:
        timeout = THRESHOLD * 6
    elif (
        should_retry
        and isinstance(test_module, ShardedTest)
        and test_module.time is not None
    ) or is_cpp_test:
        timeout = THRESHOLD * 3
    else:
        timeout = None
    print_to_stderr(f"Executing {command} ... [{datetime.now()}]")

    with ExitStack() as stack:
        output = (
            stack.enter_context(open(log_path, "w")) if options.pipe_logs else None
        )

        if should_retry:
            ret_code, was_rerun = run_test_retries(
                command=command,
                test_directory=test_directory,
                env=env,
                timeout=timeout,
                stepcurrent_key=stepcurrent_key,
                output=output,
                continue_through_error=options.continue_through_error,
            )
        else:
            command.extend([f"--sc={stepcurrent_key}", "--print-items"])
            ret_code, was_rerun = retry_shell(
                command,
                test_directory,
                stdout=output,
                stderr=output,
                env=env,
                timeout=timeout,
                retries=0,
            )

            # Pytest return code 5 means no test is collected. Exit code 4 is
            # returned when the binary is not a C++ test executable, but 4 can
            # also be returned if the file fails before running any tests. All
            # binary files under build/bin that are not C++ test at the time of
            # this writing have been excluded and new ones should be added to
            # the list of exclusions in tools/testing/discover_tests.py
            if ret_code == 5:
                ret_code = 0

    if options.pipe_logs and print_log:
        handle_log_file(
            test_module, log_path, failed=(ret_code != 0), was_rerun=was_rerun
        )
    return ret_code
572*da0073e9SAndroid Build Coastguard Worker
573*da0073e9SAndroid Build Coastguard Worker
def try_set_cpp_stack_traces(env, command, set=True):
    """Return an environment with ``TORCH_SHOW_CPP_STACKTRACES`` switched on or off.

    Used to print full C++ stack traces during retries.

    Args:
        env: Mapping of environment variables, or ``None``. It is never
            modified; a new dict is always returned.
        command: Unused; kept so existing call sites keep working.
        set: ``True`` stores ``"1"``, ``False`` stores ``"0"``.

    Returns:
        A new ``dict`` holding the entries of ``env`` plus the updated flag.
    """
    # Copy instead of mutating in place: previously a non-empty ``env`` was
    # modified for the caller while an empty/None one was silently replaced.
    updated = dict(env) if env else {}
    updated["TORCH_SHOW_CPP_STACKTRACES"] = "1" if set else "0"
    return updated
579*da0073e9SAndroid Build Coastguard Worker
580*da0073e9SAndroid Build Coastguard Worker
def run_test_retries(
    command,
    test_directory,
    env,
    timeout,
    stepcurrent_key,
    output,
    continue_through_error,
):
    """Run ``command`` repeatedly, retrying the test that failed in a fresh process.

    Run the test with -x to stop at first failure.  Rerun the test by itself.
    If it succeeds, move on to the rest of the tests in a new process.  If it
    still fails, see below.

    If continue through error is not set, then we fail fast.

    If continue through error is set, then we skip that test, and keep going.
    Basically if the same test fails 3 times in a row, skip the test on the
    next run, but still fail in the end.  The value saved in the stepcurrent
    cache file is used to keep track of the most recently run test (which is
    the one that failed if there was a failure).

    Args:
        command: Argument list to execute; the stepcurrent flag is appended to
            a copy on every iteration.
        test_directory: Working directory passed to ``retry_shell``.
        env: Environment for the child process.
        timeout: Timeout forwarded to ``retry_shell``.
        stepcurrent_key: Name of the stepcurrent cache entry for this run.
        output: File object receiving both the child's output and the
            progress messages written here.
        continue_through_error: Keep going past a consistently failing test
            instead of stopping at it.

    Returns:
        ``(ret_code, was_rerun)``. ``ret_code`` is ``1`` when at least one test
        failed three times, otherwise the exit code of the last invocation.
        ``was_rerun`` is ``True`` when any failure was recorded.
    """

    def print_to_file(s):
        print(s, file=output, flush=True)

    num_failures = defaultdict(int)

    print_items = ["--print-items"]
    sc_command = f"--sc={stepcurrent_key}"
    while True:
        ret_code, _ = retry_shell(
            command + [sc_command] + print_items,
            test_directory,
            stdout=output,
            stderr=output,
            env=env,
            timeout=timeout,
            retries=0,  # no retries here, we do it ourselves, this is because it handles timeout exceptions well
        )
        # Exit code 5 (no tests collected) is treated as success
        ret_code = 0 if ret_code == 5 else ret_code
        if ret_code == 0 and not sc_command.startswith("--rs="):
            break  # Got to the end of the test suite successfully

        if ret_code < 0:
            # A negative code is reported as a signal. The table may not have a
            # name for every signal number, and a missing name must not abort
            # the retry loop while merely formatting a log line.
            signal_label = SIGNALS_TO_NAMES_DICT.get(-ret_code, "unknown signal")
            signal_name = f" ({signal_label})"
        else:
            signal_name = ""
        print_to_file(f"Got exit code {ret_code}{signal_name}")

        # Read what just failed/ran
        try:
            with open(
                REPO_ROOT / ".pytest_cache/v/cache/stepcurrent" / stepcurrent_key
            ) as f:
                current_failure = f.read()
        except FileNotFoundError:
            print_to_file(
                "No stepcurrent file found. Either pytest didn't get to run (e.g. import error)"
                + " or file got deleted (contact dev infra)"
            )
            break

        # Stack traces are only switched back on below, for single-test retries
        env = try_set_cpp_stack_traces(env, command, set=False)

        if ret_code == 0:
            # Rerunning the previously failing test succeeded, so now we can
            # skip it and move on
            sc_command = f"--scs={stepcurrent_key}"
            print_to_file(
                "Test succeeeded in new process, continuing with the rest of the tests"
            )
        else:
            num_failures[current_failure] += 1
            if num_failures[current_failure] >= 3:
                if not continue_through_error:
                    print_to_file("Stopping at first consistent failure")
                    break
                sc_command = f"--scs={stepcurrent_key}"
                print_to_file(
                    "Test failed consistently, "
                    "continuing with the rest of the tests due to continue-through-error being set"
                )
            else:
                env = try_set_cpp_stack_traces(env, command, set=True)
                sc_command = f"--rs={stepcurrent_key}"
                print_to_file("Retrying single test...")
        print_items = []  # do not continue printing them, massive waste of space

    # The first and last characters of each recorded entry are dropped
    # (presumably surrounding quotes in the cache file -- confirm).
    consistent_failures = [
        name[1:-1] for name, count in num_failures.items() if count >= 3
    ]
    flaky_failures = [
        name[1:-1] for name, count in num_failures.items() if 0 < count < 3
    ]
    if flaky_failures:
        print_to_file(
            "The following tests failed and then succeeded when run in a new process: "
            f"{flaky_failures}"
        )
    if consistent_failures:
        print_to_file(f"The following tests failed consistently: {consistent_failures}")
        return 1, True
    return ret_code, any(count > 0 for count in num_failures.values())
675*da0073e9SAndroid Build Coastguard Worker
676*da0073e9SAndroid Build Coastguard Worker
def run_test_with_subprocess(test_module, test_directory, options):
    """Run ``test_module`` through ``run_test`` with ``--subprocess`` appended."""
    subprocess_flag = ["--subprocess"]
    return run_test(
        test_module,
        test_directory,
        options,
        extra_unittest_args=subprocess_flag,
    )
681*da0073e9SAndroid Build Coastguard Worker
682*da0073e9SAndroid Build Coastguard Worker
def _test_cpp_extensions_aot(test_directory, options, use_ninja):
    """Build the ahead-of-time C++ extension fixtures and run their test module.

    The extensions under ``<test_directory>/cpp_extensions`` are installed
    into a local ``install`` root, ``test_cpp_extensions_aot.py`` is copied to
    a ninja/no-ninja specific module name, and that copy is executed through
    ``run_test`` with the install directory on ``PYTHONPATH``.

    Returns:
        ``1`` when ninja is requested but unavailable, the non-zero exit code
        of a failing build step, or the result of ``run_test``.
    """
    if use_ninja:
        try:
            from torch.utils import cpp_extension

            cpp_extension.verify_ninja_availability()
        except RuntimeError:
            print_to_stderr(CPP_EXTENSIONS_ERROR)
            return 1

    # Wipe the build folder, if it exists already
    extensions_dir = os.path.join(test_directory, "cpp_extensions")
    build_dir = os.path.join(extensions_dir, "build")
    if os.path.exists(build_dir):
        shutil.rmtree(build_dir)

    # Build the test cpp extensions modules
    shell_env = os.environ.copy()
    shell_env["USE_NINJA"] = "1" if use_ninja else "0"
    cmd = [sys.executable, "setup.py", "install", "--root", "./install"]
    build_dirs = [extensions_dir]
    if sys.platform != "win32":
        build_dirs.append(os.path.join(extensions_dir, "no_python_abi_suffix_test"))
    for build_cwd in build_dirs:
        return_code = shell(cmd, cwd=build_cwd, env=shell_env)
        if return_code != 0:
            return return_code

    # "install" the test modules and run tests
    python_path = os.environ.get("PYTHONPATH", "")
    os.environ["USE_NINJA"] = shell_env["USE_NINJA"]
    module_name = "test_cpp_extensions_aot" + ("_ninja" if use_ninja else "_no_ninja")
    module_copy = test_directory + "/" + module_name + ".py"
    shutil.copyfile(test_directory + "/test_cpp_extensions_aot.py", module_copy)
    try:
        # install directory is the one that is named site-packages; when
        # several directories match, the last one visited wins
        candidates = [
            os.path.join(root, directory)
            for root, directories, _ in os.walk(
                os.path.join(extensions_dir, "install")
            )
            for directory in directories
            if "-packages" in directory
        ]
        install_directory = candidates[-1] if candidates else ""

        assert install_directory, "install_directory must not be empty"
        os.environ["PYTHONPATH"] = os.pathsep.join([install_directory, python_path])
        return run_test(ShardedTest(module_name, 1, 1), test_directory, options)
    finally:
        # Undo the environment changes and remove the temporary module copy
        os.environ["PYTHONPATH"] = python_path
        if os.path.exists(module_copy):
            os.remove(module_copy)
        os.environ.pop("USE_NINJA")
742*da0073e9SAndroid Build Coastguard Worker
743*da0073e9SAndroid Build Coastguard Worker
def test_cpp_extensions_aot_ninja(test_module, test_directory, options):
    """Run the AOT C++ extension tests with ninja; ``test_module`` is unused."""
    result = _test_cpp_extensions_aot(test_directory, options, use_ninja=True)
    return result
746*da0073e9SAndroid Build Coastguard Worker
747*da0073e9SAndroid Build Coastguard Worker
def test_cpp_extensions_aot_no_ninja(test_module, test_directory, options):
    """Run the AOT C++ extension tests without ninja; ``test_module`` is unused."""
    result = _test_cpp_extensions_aot(test_directory, options, use_ninja=False)
    return result
750*da0073e9SAndroid Build Coastguard Worker
751*da0073e9SAndroid Build Coastguard Worker
def test_autoload_enable(test_module, test_directory, options):
    """Run the autoload test with autoload enabled; ``test_module`` is unused."""
    result = _test_autoload(test_directory, options, enable=True)
    return result
754*da0073e9SAndroid Build Coastguard Worker
755*da0073e9SAndroid Build Coastguard Worker
def test_autoload_disable(test_module, test_directory, options):
    """Run the autoload test with autoload disabled; ``test_module`` is unused."""
    result = _test_autoload(test_directory, options, enable=False)
    return result
758*da0073e9SAndroid Build Coastguard Worker
759*da0073e9SAndroid Build Coastguard Worker
def _test_autoload(test_directory, options, enable=True):
    """Build the C++ extension fixtures and run ``test_autoload.py`` against them.

    Args:
        test_directory: Directory containing ``cpp_extensions`` and
            ``test_autoload.py``.
        options: Unused here; kept for signature parity with the other runners.
        enable: Value exported as ``TORCH_DEVICE_BACKEND_AUTOLOAD`` (``"1"`` or
            ``"0"``) while the test runs.

    Returns:
        The non-zero exit code of the build step if it fails, otherwise the
        exit code of ``test_autoload.py``.

    Raises:
        AssertionError: If no ``*-packages`` directory is found under the
            install root.
    """
    # Wipe the build folder, if it exists already
    cpp_extensions_test_dir = os.path.join(test_directory, "cpp_extensions")
    cpp_extensions_test_build_dir = os.path.join(cpp_extensions_test_dir, "build")
    if os.path.exists(cpp_extensions_test_build_dir):
        shutil.rmtree(cpp_extensions_test_build_dir)

    # Build the test cpp extensions modules
    cmd = [sys.executable, "setup.py", "install", "--root", "./install"]
    return_code = shell(cmd, cwd=cpp_extensions_test_dir, env=os.environ)
    if return_code != 0:
        return return_code

    # "install" the test modules and run tests
    python_path = os.environ.get("PYTHONPATH", "")

    try:
        install_directory = ""
        # install directory is the one that is named site-packages; when
        # several directories match, the last one visited wins
        for root, directories, _ in os.walk(
            os.path.join(cpp_extensions_test_dir, "install")
        ):
            for directory in directories:
                if "-packages" in directory:
                    install_directory = os.path.join(root, directory)

        assert install_directory, "install_directory must not be empty"
        os.environ["PYTHONPATH"] = os.pathsep.join([install_directory, python_path])
        os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = str(int(enable))

        cmd = [sys.executable, "test_autoload.py"]
        return shell(cmd, cwd=test_directory, env=os.environ)
    finally:
        os.environ["PYTHONPATH"] = python_path
        # The variable is only set after the install directory is found. Pop
        # with a default so an earlier failure (e.g. the assertion above) is
        # not masked by a KeyError raised from this cleanup.
        os.environ.pop("TORCH_DEVICE_BACKEND_AUTOLOAD", None)
795*da0073e9SAndroid Build Coastguard Worker
796*da0073e9SAndroid Build Coastguard Worker
def test_distributed(test_module, test_directory, options):
    """
    Run a distributed test module once per configured backend / init method.

    For every backend listed in ``DISTRIBUTED_TESTS_CONFIG`` (restricted to
    ``gloo`` on Windows, and skipping ``mpi`` when it is unavailable) a fresh
    temporary directory is created, ``TEMP_DIR``, ``BACKEND`` and the
    backend's own variables are exported through ``os.environ``, and the
    module is executed via ``run_test``.  The environment is restored and the
    temporary directory removed after every run, even on failure.

    Args:
        test_module: the test to run; forwarded untouched to ``run_test``.
        test_directory: directory containing the tests; forwarded to ``run_test``.
        options: parsed command-line options; only ``options.verbose`` is read
            here, the rest is forwarded to ``run_test``.

    Returns:
        The first non-zero return code reported by ``run_test``, or 0 when
        every combination succeeded.
    """
    # MPI tests are broken with Python-3.9
    mpi_available = subprocess.call(
        "command -v mpiexec", shell=True
    ) == 0 and sys.version_info < (3, 9)
    if options.verbose and not mpi_available:
        print_to_stderr("MPI not available -- MPI backend tests will be skipped")

    def mpiexec_accepts(command):
        # True when the given mpiexec probe command exits successfully; all
        # probe output is discarded.
        return (
            subprocess.call(
                command,
                shell=True,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.STDOUT,
            )
            == 0
        )

    config = DISTRIBUTED_TESTS_CONFIG
    for backend, env_vars in config.items():
        if sys.platform == "win32" and backend != "gloo":
            continue
        if backend == "mpi" and not mpi_available:
            continue
        # NOTE(review): within this function with_init_file only influences
        # the Windows skip and the verbose message; nothing visible here hands
        # the chosen init method to the test run -- confirm this is intended.
        # A tuple (rather than a set) keeps the iteration order explicit.
        for with_init_file in (False, True):
            if sys.platform == "win32" and not with_init_file:
                continue
            tmp_dir = tempfile.mkdtemp()
            if options.verbose:
                init_str = "with {} init_method"
                with_init = init_str.format("file" if with_init_file else "env")
                print_to_stderr(
                    f"Running distributed tests for the {backend} backend {with_init}"
                )
            # Snapshot the environment so it can be restored verbatim below.
            old_environ = dict(os.environ)
            try:
                os.environ["TEMP_DIR"] = tmp_dir
                os.environ["BACKEND"] = backend
                os.environ.update(env_vars)
                os.mkdir(os.path.join(tmp_dir, "barrier"))
                os.mkdir(os.path.join(tmp_dir, "test_dir"))
                if backend == "mpi":
                    # Probe which optional flags this mpiexec understands.
                    allowrunasroot_opt = (
                        "--allow-run-as-root"
                        if mpiexec_accepts(
                            'mpiexec --allow-run-as-root -n 1 bash -c ""'
                        )
                        else ""
                    )
                    noprefix_opt = (
                        "--noprefix"
                        if mpiexec_accepts(
                            f'mpiexec {allowrunasroot_opt} -n 1 --noprefix bash -c ""'
                        )
                        else ""
                    )

                    # Only pass the flags that were actually accepted: an empty
                    # string would reach mpiexec as a (bogus) positional argument.
                    mpiexec = ["mpiexec", "-n", "3"]
                    mpiexec.extend(
                        opt for opt in (noprefix_opt, allowrunasroot_opt) if opt
                    )

                    return_code = run_test(
                        test_module, test_directory, options, launcher_cmd=mpiexec
                    )
                else:
                    return_code = run_test(
                        test_module,
                        test_directory,
                        options,
                        extra_unittest_args=["--subprocess"],
                    )
                if return_code != 0:
                    return return_code
            finally:
                try:
                    shutil.rmtree(tmp_dir)
                finally:
                    # Restore the environment even if removing tmp_dir failed.
                    os.environ.clear()
                    os.environ.update(old_environ)
    return 0
873*da0073e9SAndroid Build Coastguard Worker
874*da0073e9SAndroid Build Coastguard Worker
def run_doctests(test_module, test_directory, options):
    """
    Execute the xdoctest runner on the torch library itself.

    The incoming test module is assumed to be the one called doctest;
    ``test_module`` and ``test_directory`` are not read.  Feature groups that
    end up enabled are advertised to the doctests through ``TORCH_DOCTEST_*``
    environment variables.

    Returns 1 when xdoctest reports at least one failed doctest, else 0.
    """
    import xdoctest

    exclude_module_list = ["torch._vendor.*"]

    # TODO: expose these options to the user
    # For now every feature-conditional group is disabled, except onnx which
    # is probed below.  Setting an entry to "auto" makes the code below detect
    # availability at runtime.
    enabled = {
        "lapack": 0,
        "cuda": 0,
        "cuda1": 0,
        "qengine": 0,
        "autograd_profiler": 0,
        "cpp_ext": 0,
        "monitor": 0,
        "onnx": "auto",
    }

    # Resolve "auto" entries by testing whether the feature is available.
    if enabled["cuda"] == "auto" and torch.cuda.is_available():
        enabled["cuda"] = True

    if enabled["cuda1"] == "auto":
        if torch.cuda.is_available() and torch.cuda.device_count() > 1:
            enabled["cuda1"] = True

    if enabled["lapack"] == "auto" and torch._C.has_lapack:
        enabled["lapack"] = True

    if enabled["qengine"] == "auto":
        try:
            # Is there a better check if quantization is enabled?
            import torch.ao.nn.quantized as nnq  # NOQA: F401

            torch.backends.quantized.engine = "qnnpack"
            torch.backends.quantized.engine = "fbgemm"
            enabled["qengine"] = True
        except (ImportError, RuntimeError):
            pass

    if enabled["onnx"] == "auto":
        try:
            import onnx  # NOQA: F401
            import onnxruntime  # NOQA: F401
            import onnxscript  # NOQA: F401

            onnx_usable = True
        except ImportError:
            onnx_usable = False
            # Without the onnx stack, keep xdoctest away from torch.onnx.
            exclude_module_list.append("torch.onnx.*")
        enabled["onnx"] = onnx_usable

    # Set doctest environment variables for every enabled feature group.
    feature_env_vars = (
        ("cuda", "TORCH_DOCTEST_CUDA"),
        ("cuda1", "TORCH_DOCTEST_CUDA1"),
        ("lapack", "TORCH_DOCTEST_LAPACK"),
        ("qengine", "TORCH_DOCTEST_QENGINE"),
        ("autograd_profiler", "TORCH_DOCTEST_AUTOGRAD_PROFILER"),
        ("cpp_ext", "TORCH_DOCTEST_CPP_EXT"),
        ("monitor", "TORCH_DOCTEST_MONITOR"),
        ("onnx", "TORCH_DOCTEST_ONNX"),
    )
    for feature, env_name in feature_env_vars:
        if enabled[feature]:
            os.environ[env_name] = "1"

    if 0:
        # TODO: could try to enable some of these
        for env_name in (
            "TORCH_DOCTEST_QUANTIZED_DYNAMIC",
            "TORCH_DOCTEST_ANOMALY",
            "TORCH_DOCTEST_AUTOGRAD",
            "TORCH_DOCTEST_HUB",
            "TORCH_DOCTEST_DATALOADER",
            "TORCH_DOCTEST_FUTURES",
        ):
            os.environ[env_name] = "1"

    pkgpath = os.path.dirname(torch.__file__)

    # Statements made available to every doctest; the separator is a literal
    # backslash-n (raw string), exactly as in the original configuration.
    preamble = [
        "from torch import nn",
        "import torch.nn.functional as F",
        "import torch",
    ]
    xdoctest_config = {
        "global_exec": r"\n".join(preamble),
        "analysis": "static",  # set to "auto" to test doctests in compiled modules
        "style": "google",
        "options": "+IGNORE_WHITESPACE",
    }
    run_summary = xdoctest.runner.doctest_module(
        os.fspath(pkgpath),
        config=xdoctest_config,
        verbose=max(1, options.verbose),  # never run xdoctest fully silent
        command=options.xdoctest_command,
        argv=[],
        exclude=exclude_module_list,
    )
    if run_summary.get("n_failed", 0):
        return 1
    return 0
998*da0073e9SAndroid Build Coastguard Worker
999*da0073e9SAndroid Build Coastguard Worker
def sanitize_file_name(file: str):
    """Return *file* with path separators turned into "." and spaces into "_"."""
    # One pass over the string: "\" -> ".", "/" -> ".", " " -> "_".
    return file.translate(str.maketrans("\\/ ", ".._"))
1002*da0073e9SAndroid Build Coastguard Worker
1003*da0073e9SAndroid Build Coastguard Worker
def handle_log_file(
    test: ShardedTest, file_path: str, failed: bool, was_rerun: bool
) -> None:
    """
    Archive a test's log file under test/test-reports and report it on stderr.

    The log at ``file_path`` is read, then renamed to a sanitized, randomly
    suffixed name below ``REPO_ROOT / "test/test-reports"``.  When the test
    failed, was rerun, or the log contains a "=== RERUNS ===" marker, the
    whole log is printed; otherwise only a short success notice and the
    lines listing what ran in the shard are printed.
    """
    test_name = str(test)
    with open(file_path, errors="ignore") as log:
        log_text = log.read()

    random_suffix = os.urandom(8).hex()
    new_file = "test/test-reports/" + sanitize_file_name(
        f"{test_name}_{random_suffix}_.log"
    )
    os.rename(file_path, REPO_ROOT / new_file)

    if failed or was_rerun or "=== RERUNS ===" in log_text:
        # Something noteworthy happened: print the entire file.
        print_to_stderr(f"\nPRINTING LOG FILE of {test_name} ({new_file})")
        print_to_stderr(log_text)
        print_to_stderr(f"FINISHED PRINTING LOG FILE of {test_name} ({new_file})\n")
        return

    # Success with no retries (test-level retries are only detectable via the
    # RERUNS marker short of reparsing the xml): print only what tests ran.
    print_to_stderr(
        f"\n{test_name} was successful, full logs can be found in artifacts with path {new_file}"
    )
    shard_lines = [
        line
        for line in log_text.splitlines()
        if re.search("Running .* items in this shard:", line)
    ]
    for line in shard_lines:
        print_to_stderr(line.rstrip())
    print_to_stderr("")
1032*da0073e9SAndroid Build Coastguard Worker
1033*da0073e9SAndroid Build Coastguard Worker
def get_pytest_args(options, is_cpp_test=False, is_distributed_test=False):
    """
    Build the list of pytest command-line arguments for a test run.

    Args:
        options: parsed command-line options; ``options.pytest_k_expr`` is
            forwarded as pytest's ``-k`` expression when non-empty.
        is_cpp_test: whether the arguments are for a C++ test.
        is_distributed_test: whether the test is a distributed test; only
            affects the rerun count in rerun-disabled-tests mode.

    Returns:
        A new list of argument strings.
    """
    args = ["-vv", "-rfEX"]

    if is_cpp_test:
        # Use pytest-xdist to run C++ tests in parallel, as running them
        # sequentially using run_test is much slower than running them directly.
        args += ["-n", str(NUM_PROCS)]
        if IS_CI:
            # C++ tests won't go into common_utils, so the option to generate
            # the XML test report has to be added here.
            args += ["--junit-xml-reruns", get_report_path(pytest=True)]
    else:
        # We have a custom pytest shard that conflicts with the normal xdist
        # plugin, so it is disabled for non-C++ tests.
        args += ["-p", "no:xdist", "--use-pytest"]

    k_expr = options.pytest_k_expr
    if k_expr:
        args += ["-k", k_expr]

    if RERUN_DISABLED_TESTS:
        # Under rerun-disabled-tests mode the same tests are run multiple times
        # to determine their flakiness status, 50 times by default.
        # Distributed tests are too slow for that (x50 makes the jobs time out
        # after 3+ hours), so they get fewer reruns.  At least 150 instances of
        # the test are needed every 2 weeks to satisfy the Rockset query
        # (15 x 14 = 210).  The same logic applies to ASAN, which is also slow.
        flake_runs = 15 if is_distributed_test or TEST_WITH_ASAN else 50
        args += ["--flake-finder", f"--flake-runs={flake_runs}"]
    else:
        # Normal mode: retry a failed test 2 more times.  -x means stop at the
        # first failure.
        args += ["-x", "--reruns=2"]

    return args
1072*da0073e9SAndroid Build Coastguard Worker    return pytest_args
1073*da0073e9SAndroid Build Coastguard Worker
1074*da0073e9SAndroid Build Coastguard Worker
def run_ci_sanity_check(test: ShardedTest, test_directory, options):
    """
    Run test_ci_sanity_check_fail and verify that it fails as expected.

    Returns 1 when the test did not exit with return code 1.  Otherwise the
    log files and report directories the run left under test/test-reports
    are deleted and 0 is returned.
    """
    assert (
        test.name == "test_ci_sanity_check_fail"
    ), f"This handler only works for test_ci_sanity_check_fail, got {test.name}"

    # This test should fail; any other outcome is an error for the handler.
    if run_test(test, test_directory, options, print_log=False) != 1:
        return 1

    reports_root = str(REPO_ROOT / "test/test-reports")
    # Delete the log files and xmls generated by the test
    log_pattern = f"{reports_root}/{test.name}*.log"
    for log_path in glob.glob(log_pattern):
        os.remove(log_path)
    dir_pattern = f"{reports_root}/**/{test.name}"
    for report_dir in glob.glob(dir_pattern):
        shutil.rmtree(report_dir)
    return 0
1090*da0073e9SAndroid Build Coastguard Worker
1091*da0073e9SAndroid Build Coastguard Worker
# Maps a test name to the function used to run it in place of the default
# runner.  Keys are grouped by handler; insertion order is kept as before.
CUSTOM_HANDLERS = {
    **dict.fromkeys(
        (
            "test_cuda_primary_ctx",
            "test_cuda_nvml_based_avail",
            "test_cuda_trace",
        ),
        run_test_with_subprocess,
    ),
    "test_cpp_extensions_aot_no_ninja": test_cpp_extensions_aot_no_ninja,
    "test_cpp_extensions_aot_ninja": test_cpp_extensions_aot_ninja,
    **dict.fromkeys(
        (
            "distributed/test_distributed_spawn",
            "distributed/algorithms/quantization/test_quantization",
        ),
        test_distributed,
    ),
    **dict.fromkeys(
        (
            "distributed/test_c10d_nccl",
            "distributed/test_c10d_gloo",
            "distributed/test_c10d_ucc",
            "distributed/test_c10d_common",
            "distributed/test_c10d_spawn_gloo",
            "distributed/test_c10d_spawn_nccl",
            "distributed/test_c10d_spawn_ucc",
            "distributed/test_store",
            "distributed/test_pg_wrapper",
            "distributed/rpc/test_faulty_agent",
            "distributed/rpc/test_tensorpipe_agent",
            "distributed/rpc/test_share_memory",
            "distributed/rpc/cuda/test_tensorpipe_agent",
        ),
        run_test_with_subprocess,
    ),
    "doctests": run_doctests,
    "test_ci_sanity_check_fail": run_ci_sanity_check,
    "test_autoload_enable": test_autoload_enable,
    "test_autoload_disable": test_autoload_disable,
}


# NOTE(review): presumably the tests for which pytest retries are skipped --
# confirm against the code that consumes this set.
PYTEST_SKIP_RETRIES = {"test_public_bindings"}
1121*da0073e9SAndroid Build Coastguard Worker
1122*da0073e9SAndroid Build Coastguard Worker
1123*da0073e9SAndroid Build Coastguard Workerdef parse_args():
1124*da0073e9SAndroid Build Coastguard Worker    parser = argparse.ArgumentParser(
1125*da0073e9SAndroid Build Coastguard Worker        description="Run the PyTorch unit test suite",
1126*da0073e9SAndroid Build Coastguard Worker        epilog="where TESTS is any of: {}".format(", ".join(TESTS)),
1127*da0073e9SAndroid Build Coastguard Worker        formatter_class=argparse.RawTextHelpFormatter,
1128*da0073e9SAndroid Build Coastguard Worker    )
1129*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
1130*da0073e9SAndroid Build Coastguard Worker        "-v",
1131*da0073e9SAndroid Build Coastguard Worker        "--verbose",
1132*da0073e9SAndroid Build Coastguard Worker        action="count",
1133*da0073e9SAndroid Build Coastguard Worker        default=0,
1134*da0073e9SAndroid Build Coastguard Worker        help="Print verbose information and test-by-test results",
1135*da0073e9SAndroid Build Coastguard Worker    )
1136*da0073e9SAndroid Build Coastguard Worker    if sys.version_info >= (3, 9):
1137*da0073e9SAndroid Build Coastguard Worker        parser.add_argument(
1138*da0073e9SAndroid Build Coastguard Worker            "--showlocals",
1139*da0073e9SAndroid Build Coastguard Worker            action=argparse.BooleanOptionalAction,
1140*da0073e9SAndroid Build Coastguard Worker            default=strtobool(os.environ.get("TEST_SHOWLOCALS", "False")),
1141*da0073e9SAndroid Build Coastguard Worker            help="Show local variables in tracebacks (default: True)",
1142*da0073e9SAndroid Build Coastguard Worker        )
1143*da0073e9SAndroid Build Coastguard Worker    else:
1144*da0073e9SAndroid Build Coastguard Worker        parser.add_argument(
1145*da0073e9SAndroid Build Coastguard Worker            "--showlocals",
1146*da0073e9SAndroid Build Coastguard Worker            action="store_true",
1147*da0073e9SAndroid Build Coastguard Worker            default=strtobool(os.environ.get("TEST_SHOWLOCALS", "False")),
1148*da0073e9SAndroid Build Coastguard Worker            help="Show local variables in tracebacks (default: True)",
1149*da0073e9SAndroid Build Coastguard Worker        )
1150*da0073e9SAndroid Build Coastguard Worker        parser.add_argument("--no-showlocals", dest="showlocals", action="store_false")
1151*da0073e9SAndroid Build Coastguard Worker    parser.add_argument("--jit", "--jit", action="store_true", help="run all jit tests")
1152*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
1153*da0073e9SAndroid Build Coastguard Worker        "--distributed-tests",
1154*da0073e9SAndroid Build Coastguard Worker        "--distributed-tests",
1155*da0073e9SAndroid Build Coastguard Worker        action="store_true",
1156*da0073e9SAndroid Build Coastguard Worker        help="Run all distributed tests",
1157*da0073e9SAndroid Build Coastguard Worker    )
1158*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
1159*da0073e9SAndroid Build Coastguard Worker        "--functorch",
1160*da0073e9SAndroid Build Coastguard Worker        "--functorch",
1161*da0073e9SAndroid Build Coastguard Worker        action="store_true",
1162*da0073e9SAndroid Build Coastguard Worker        help=(
1163*da0073e9SAndroid Build Coastguard Worker            "If this flag is present, we will only run functorch tests. "
1164*da0073e9SAndroid Build Coastguard Worker            "If this flag is not present, we will run all tests "
1165*da0073e9SAndroid Build Coastguard Worker            "(including functorch tests)."
1166*da0073e9SAndroid Build Coastguard Worker        ),
1167*da0073e9SAndroid Build Coastguard Worker    )
1168*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
1169*da0073e9SAndroid Build Coastguard Worker        "--mps",
1170*da0073e9SAndroid Build Coastguard Worker        "--mps",
1171*da0073e9SAndroid Build Coastguard Worker        action="store_true",
1172*da0073e9SAndroid Build Coastguard Worker        help=("If this flag is present, we will only run test_mps and test_metal"),
1173*da0073e9SAndroid Build Coastguard Worker    )
1174*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
1175*da0073e9SAndroid Build Coastguard Worker        "--xpu",
1176*da0073e9SAndroid Build Coastguard Worker        "--xpu",
1177*da0073e9SAndroid Build Coastguard Worker        action="store_true",
1178*da0073e9SAndroid Build Coastguard Worker        help=("If this flag is present, we will run xpu tests except XPU_BLOCK_LIST"),
1179*da0073e9SAndroid Build Coastguard Worker    )
1180*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
1181*da0073e9SAndroid Build Coastguard Worker        "--cpp",
1182*da0073e9SAndroid Build Coastguard Worker        "--cpp",
1183*da0073e9SAndroid Build Coastguard Worker        action="store_true",
1184*da0073e9SAndroid Build Coastguard Worker        help=("If this flag is present, we will only run C++ tests"),
1185*da0073e9SAndroid Build Coastguard Worker    )
1186*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
1187*da0073e9SAndroid Build Coastguard Worker        "-core",
1188*da0073e9SAndroid Build Coastguard Worker        "--core",
1189*da0073e9SAndroid Build Coastguard Worker        action="store_true",
1190*da0073e9SAndroid Build Coastguard Worker        help="Only run core tests, or tests that validate PyTorch's ops, modules,"
1191*da0073e9SAndroid Build Coastguard Worker        "and autograd. They are defined by CORE_TEST_LIST.",
1192*da0073e9SAndroid Build Coastguard Worker    )
1193*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
1194*da0073e9SAndroid Build Coastguard Worker        "--onnx",
1195*da0073e9SAndroid Build Coastguard Worker        "--onnx",
1196*da0073e9SAndroid Build Coastguard Worker        action="store_true",
1197*da0073e9SAndroid Build Coastguard Worker        help=(
1198*da0073e9SAndroid Build Coastguard Worker            "Only run ONNX tests, or tests that validate PyTorch's ONNX export. "
1199*da0073e9SAndroid Build Coastguard Worker            "If this flag is not present, we will exclude ONNX tests."
1200*da0073e9SAndroid Build Coastguard Worker        ),
1201*da0073e9SAndroid Build Coastguard Worker    )
1202*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
1203*da0073e9SAndroid Build Coastguard Worker        "-k",
1204*da0073e9SAndroid Build Coastguard Worker        "--pytest-k-expr",
1205*da0073e9SAndroid Build Coastguard Worker        default="",
1206*da0073e9SAndroid Build Coastguard Worker        help="Pass to pytest as its -k expr argument",
1207*da0073e9SAndroid Build Coastguard Worker    )
1208*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
1209*da0073e9SAndroid Build Coastguard Worker        "-c",
1210*da0073e9SAndroid Build Coastguard Worker        "--coverage",
1211*da0073e9SAndroid Build Coastguard Worker        action="store_true",
1212*da0073e9SAndroid Build Coastguard Worker        help="enable coverage",
1213*da0073e9SAndroid Build Coastguard Worker        default=PYTORCH_COLLECT_COVERAGE,
1214*da0073e9SAndroid Build Coastguard Worker    )
1215*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
1216*da0073e9SAndroid Build Coastguard Worker        "-i",
1217*da0073e9SAndroid Build Coastguard Worker        "--include",
1218*da0073e9SAndroid Build Coastguard Worker        nargs="+",
1219*da0073e9SAndroid Build Coastguard Worker        choices=TestChoices(TESTS),
1220*da0073e9SAndroid Build Coastguard Worker        default=TESTS,
1221*da0073e9SAndroid Build Coastguard Worker        metavar="TESTS",
1222*da0073e9SAndroid Build Coastguard Worker        help="select a set of tests to include (defaults to ALL tests)."
1223*da0073e9SAndroid Build Coastguard Worker        " tests must be a part of the TESTS list defined in run_test.py",
1224*da0073e9SAndroid Build Coastguard Worker    )
1225*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
1226*da0073e9SAndroid Build Coastguard Worker        "-x",
1227*da0073e9SAndroid Build Coastguard Worker        "--exclude",
1228*da0073e9SAndroid Build Coastguard Worker        nargs="+",
1229*da0073e9SAndroid Build Coastguard Worker        choices=TESTS,
1230*da0073e9SAndroid Build Coastguard Worker        metavar="TESTS",
1231*da0073e9SAndroid Build Coastguard Worker        default=[],
1232*da0073e9SAndroid Build Coastguard Worker        help="select a set of tests to exclude",
1233*da0073e9SAndroid Build Coastguard Worker    )
1234*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
1235*da0073e9SAndroid Build Coastguard Worker        "--ignore-win-blocklist",
1236*da0073e9SAndroid Build Coastguard Worker        action="store_true",
1237*da0073e9SAndroid Build Coastguard Worker        help="always run blocklisted windows tests",
1238*da0073e9SAndroid Build Coastguard Worker    )
1239*da0073e9SAndroid Build Coastguard Worker    # NS: Disable target determination until it can be made more reliable
1240*da0073e9SAndroid Build Coastguard Worker    # parser.add_argument(
1241*da0073e9SAndroid Build Coastguard Worker    #     "--determine-from",
1242*da0073e9SAndroid Build Coastguard Worker    #     help="File of affected source filenames to determine which tests to run.",
1243*da0073e9SAndroid Build Coastguard Worker    # )
1244*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
1245*da0073e9SAndroid Build Coastguard Worker        "--continue-through-error",
1246*da0073e9SAndroid Build Coastguard Worker        "--keep-going",
1247*da0073e9SAndroid Build Coastguard Worker        action="store_true",
1248*da0073e9SAndroid Build Coastguard Worker        help="Runs the full test suite despite one of the tests failing",
1249*da0073e9SAndroid Build Coastguard Worker        default=strtobool(os.environ.get("CONTINUE_THROUGH_ERROR", "False")),
1250*da0073e9SAndroid Build Coastguard Worker    )
1251*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
1252*da0073e9SAndroid Build Coastguard Worker        "--pipe-logs",
1253*da0073e9SAndroid Build Coastguard Worker        action="store_true",
1254*da0073e9SAndroid Build Coastguard Worker        help="Print logs to output file while running tests.  True if in CI and env var is not set",
1255*da0073e9SAndroid Build Coastguard Worker        default=IS_CI and not strtobool(os.environ.get("VERBOSE_TEST_LOGS", "False")),
1256*da0073e9SAndroid Build Coastguard Worker    )
1257*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
1258*da0073e9SAndroid Build Coastguard Worker        "--enable-timeout",
1259*da0073e9SAndroid Build Coastguard Worker        action="store_true",
1260*da0073e9SAndroid Build Coastguard Worker        help="Set a timeout based on the test times json file.  Only works if there are test times available",
1261*da0073e9SAndroid Build Coastguard Worker        default=IS_CI and not strtobool(os.environ.get("NO_TEST_TIMEOUT", "False")),
1262*da0073e9SAndroid Build Coastguard Worker    )
1263*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
1264*da0073e9SAndroid Build Coastguard Worker        "--enable-td",
1265*da0073e9SAndroid Build Coastguard Worker        action="store_true",
1266*da0073e9SAndroid Build Coastguard Worker        help="Enables removing tests based on TD",
1267*da0073e9SAndroid Build Coastguard Worker        default=IS_CI
1268*da0073e9SAndroid Build Coastguard Worker        and (
1269*da0073e9SAndroid Build Coastguard Worker            TEST_WITH_CROSSREF
1270*da0073e9SAndroid Build Coastguard Worker            or TEST_WITH_ASAN
1271*da0073e9SAndroid Build Coastguard Worker            or (TEST_CONFIG == "distributed" and TEST_CUDA)
1272*da0073e9SAndroid Build Coastguard Worker            or (IS_WINDOWS and not TEST_CUDA)
1273*da0073e9SAndroid Build Coastguard Worker            or TEST_CONFIG == "nogpu_AVX512"
1274*da0073e9SAndroid Build Coastguard Worker            or TEST_CONFIG == "nogpu_NO_AVX2"
1275*da0073e9SAndroid Build Coastguard Worker            or TEST_CONFIG == "default"
1276*da0073e9SAndroid Build Coastguard Worker        )
1277*da0073e9SAndroid Build Coastguard Worker        and get_pr_number() is not None
1278*da0073e9SAndroid Build Coastguard Worker        and not strtobool(os.environ.get("NO_TD", "False"))
1279*da0073e9SAndroid Build Coastguard Worker        and not TEST_WITH_ROCM
1280*da0073e9SAndroid Build Coastguard Worker        and not IS_MACOS
1281*da0073e9SAndroid Build Coastguard Worker        and "xpu" not in BUILD_ENVIRONMENT
1282*da0073e9SAndroid Build Coastguard Worker        and "onnx" not in BUILD_ENVIRONMENT
1283*da0073e9SAndroid Build Coastguard Worker        and os.environ.get("GITHUB_WORKFLOW", "slow") in ("trunk", "pull"),
1284*da0073e9SAndroid Build Coastguard Worker    )
1285*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
1286*da0073e9SAndroid Build Coastguard Worker        "--shard",
1287*da0073e9SAndroid Build Coastguard Worker        nargs=2,
1288*da0073e9SAndroid Build Coastguard Worker        type=int,
1289*da0073e9SAndroid Build Coastguard Worker        help="runs a shard of the tests (taking into account other selections), e.g., "
1290*da0073e9SAndroid Build Coastguard Worker        "--shard 2 3 will break up the selected tests into 3 shards and run the tests "
1291*da0073e9SAndroid Build Coastguard Worker        "in the 2nd shard (the first number should not exceed the second)",
1292*da0073e9SAndroid Build Coastguard Worker    )
1293*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
1294*da0073e9SAndroid Build Coastguard Worker        "--exclude-jit-executor",
1295*da0073e9SAndroid Build Coastguard Worker        action="store_true",
1296*da0073e9SAndroid Build Coastguard Worker        help="exclude tests that are run for a specific jit config",
1297*da0073e9SAndroid Build Coastguard Worker    )
1298*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
1299*da0073e9SAndroid Build Coastguard Worker        "--exclude-torch-export-tests",
1300*da0073e9SAndroid Build Coastguard Worker        action="store_true",
1301*da0073e9SAndroid Build Coastguard Worker        help="exclude torch export tests",
1302*da0073e9SAndroid Build Coastguard Worker    )
1303*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
1304*da0073e9SAndroid Build Coastguard Worker        "--exclude-distributed-tests",
1305*da0073e9SAndroid Build Coastguard Worker        action="store_true",
1306*da0073e9SAndroid Build Coastguard Worker        help="exclude distributed tests",
1307*da0073e9SAndroid Build Coastguard Worker    )
1308*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
1309*da0073e9SAndroid Build Coastguard Worker        "--exclude-inductor-tests",
1310*da0073e9SAndroid Build Coastguard Worker        action="store_true",
1311*da0073e9SAndroid Build Coastguard Worker        help="exclude inductor tests",
1312*da0073e9SAndroid Build Coastguard Worker    )
1313*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
1314*da0073e9SAndroid Build Coastguard Worker        "--dry-run",
1315*da0073e9SAndroid Build Coastguard Worker        action="store_true",
1316*da0073e9SAndroid Build Coastguard Worker        help="Only list the test that will run.",
1317*da0073e9SAndroid Build Coastguard Worker    )
1318*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
1319*da0073e9SAndroid Build Coastguard Worker        "--xdoctest-command",
1320*da0073e9SAndroid Build Coastguard Worker        default="all",
1321*da0073e9SAndroid Build Coastguard Worker        help=(
1322*da0073e9SAndroid Build Coastguard Worker            "Control the specific doctest action. "
1323*da0073e9SAndroid Build Coastguard Worker            "Use 'list' to simply parse doctests and check syntax. "
1324*da0073e9SAndroid Build Coastguard Worker            "Use 'all' to execute all doctests or specify a specific "
1325*da0073e9SAndroid Build Coastguard Worker            "doctest to run"
1326*da0073e9SAndroid Build Coastguard Worker        ),
1327*da0073e9SAndroid Build Coastguard Worker    )
1328*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
1329*da0073e9SAndroid Build Coastguard Worker        "--no-translation-validation",
1330*da0073e9SAndroid Build Coastguard Worker        action="store_false",
1331*da0073e9SAndroid Build Coastguard Worker        help="Run tests without translation validation.",
1332*da0073e9SAndroid Build Coastguard Worker    )
1333*da0073e9SAndroid Build Coastguard Worker
1334*da0073e9SAndroid Build Coastguard Worker    group = parser.add_mutually_exclusive_group()
1335*da0073e9SAndroid Build Coastguard Worker    group.add_argument(
1336*da0073e9SAndroid Build Coastguard Worker        "--dynamo",
1337*da0073e9SAndroid Build Coastguard Worker        action="store_true",
1338*da0073e9SAndroid Build Coastguard Worker        help="Run tests with TorchDynamo+EagerBackend turned on",
1339*da0073e9SAndroid Build Coastguard Worker    )
1340*da0073e9SAndroid Build Coastguard Worker    group.add_argument(
1341*da0073e9SAndroid Build Coastguard Worker        "--inductor",
1342*da0073e9SAndroid Build Coastguard Worker        action="store_true",
1343*da0073e9SAndroid Build Coastguard Worker        help="Run tests with TorchInductor turned on",
1344*da0073e9SAndroid Build Coastguard Worker    )
1345*da0073e9SAndroid Build Coastguard Worker
1346*da0073e9SAndroid Build Coastguard Worker    args, extra = parser.parse_known_args()
1347*da0073e9SAndroid Build Coastguard Worker    if "--" in extra:
1348*da0073e9SAndroid Build Coastguard Worker        extra.remove("--")
1349*da0073e9SAndroid Build Coastguard Worker    args.additional_args = extra
1350*da0073e9SAndroid Build Coastguard Worker    return args
1351*da0073e9SAndroid Build Coastguard Worker
1352*da0073e9SAndroid Build Coastguard Worker
1353*da0073e9SAndroid Build Coastguard Workerdef exclude_tests(
1354*da0073e9SAndroid Build Coastguard Worker    exclude_list, selected_tests, exclude_message=None, exact_match=False
1355*da0073e9SAndroid Build Coastguard Worker):
1356*da0073e9SAndroid Build Coastguard Worker    for exclude_test in exclude_list:
1357*da0073e9SAndroid Build Coastguard Worker        tests_copy = selected_tests[:]
1358*da0073e9SAndroid Build Coastguard Worker        for test in tests_copy:
1359*da0073e9SAndroid Build Coastguard Worker            if (
1360*da0073e9SAndroid Build Coastguard Worker                not exact_match and test.startswith(exclude_test)
1361*da0073e9SAndroid Build Coastguard Worker            ) or test == exclude_test:
1362*da0073e9SAndroid Build Coastguard Worker                if exclude_message is not None:
1363*da0073e9SAndroid Build Coastguard Worker                    print_to_stderr(f"Excluding {test} {exclude_message}")
1364*da0073e9SAndroid Build Coastguard Worker                selected_tests.remove(test)
1365*da0073e9SAndroid Build Coastguard Worker    return selected_tests
1366*da0073e9SAndroid Build Coastguard Worker
1367*da0073e9SAndroid Build Coastguard Worker
def must_serial(file: Union[str, ShardedTest]) -> bool:
    """Return True when any of the serial-only conditions applies to a test.

    Args:
        file: a test name, or a ShardedTest whose ``name`` attribute is used.

    Returns:
        True if serial execution is forced through the environment, if the
        name (or the TEST_CONFIG env var) contains DISTRIBUTED_TEST_PREFIX,
        if the name is listed in one of the serial-only collections, or if
        NUM_PROCS is 1; False otherwise.
    """
    test_name = file.name if isinstance(file, ShardedTest) else file

    # Environment-level conditions.
    if os.getenv("PYTORCH_TEST_RUN_EVERYTHING_IN_SERIAL", "0") == "1":
        return True
    if DISTRIBUTED_TEST_PREFIX in os.getenv("TEST_CONFIG", ""):
        return True

    # Conditions that depend on the test name itself.
    if DISTRIBUTED_TEST_PREFIX in test_name:
        return True
    serial_collections = (
        CUSTOM_HANDLERS,
        RUN_PARALLEL_BLOCKLIST,
        CI_SERIAL_LIST,
        JIT_EXECUTOR_TESTS,
        ONNX_SERIAL_LIST,
    )
    if any(test_name in collection for collection in serial_collections):
        return True

    # With a single process there is nothing to parallelize over.
    return NUM_PROCS == 1
1382*da0073e9SAndroid Build Coastguard Worker
1383*da0073e9SAndroid Build Coastguard Worker
def can_run_in_pytest(test):
    """Report whether pytest may be used, judging only by the environment.

    The ``test`` argument is accepted but not consulted. The result is True
    when PYTORCH_TEST_DO_NOT_USE_PYTEST is unset or equal to "0".
    """
    opt_out = os.environ.get("PYTORCH_TEST_DO_NOT_USE_PYTEST", "0")
    return opt_out == "0"
1386*da0073e9SAndroid Build Coastguard Worker
1387*da0073e9SAndroid Build Coastguard Worker
def get_selected_tests(options) -> List[str]:
    """Compute the list of tests to run from the parsed command-line options.

    Starting from ``options.include``, the "only run X" flags (jit,
    distributed, core, functorch, cpp, mps, onnx) narrow the selection, then
    everything accumulated in ``options.exclude`` is removed, followed by the
    platform/build specific blocklists (Windows, ROCm, no distributed support,
    no LAPACK, slow gradcheck).

    Side effects:
        ``options.exclude`` is extended in place with the implied exclusions.
        On win32 with a non-x64 target, extra entries are appended to the
        module-level WINDOWS_BLOCKLIST.

    Args:
        options: the argparse namespace produced by parse_args().

    Returns:
        A new list holding ``parse_test_module(name)`` for every selected test.
    """
    # Work on a copy: exclude_tests() removes entries in place, and
    # options.include may be the module-level TESTS list itself (it is the
    # argparse default), which must not be altered by test selection.
    selected_tests = list(options.include)

    # filter if there's JIT only and distributed only test options
    if options.jit:
        selected_tests = [tname for tname in selected_tests if "jit" in tname]

    if options.distributed_tests:
        selected_tests = [
            tname for tname in selected_tests if tname in DISTRIBUTED_TESTS
        ]

    # Filter to only run core tests when --core option is specified
    if options.core:
        selected_tests = [tname for tname in selected_tests if tname in CORE_TEST_LIST]

    # Filter to only run functorch tests when --functorch option is specified
    if options.functorch:
        selected_tests = [tname for tname in selected_tests if tname in FUNCTORCH_TESTS]

    if options.cpp:
        selected_tests = [tname for tname in selected_tests if tname in CPP_TESTS]
    else:
        # Exclude all C++ tests otherwise as they are still handled differently
        # than Python test at the moment
        options.exclude.extend(CPP_TESTS)

    if options.mps:
        # --mps replaces the whole selection instead of filtering it.
        selected_tests = ["test_mps", "test_metal", "test_modules", "test_nn"]
    else:
        # Exclude all mps tests otherwise
        options.exclude.extend(["test_mps", "test_metal"])

    if options.xpu:
        selected_tests = exclude_tests(XPU_BLOCKLIST, selected_tests, "on XPU")
    else:
        # Exclude all xpu specific tests otherwise
        options.exclude.extend(XPU_TEST)

    # Filter to only run onnx tests when --onnx option is specified
    onnx_tests = [tname for tname in selected_tests if tname in ONNX_TESTS]
    if options.onnx:
        selected_tests = onnx_tests
    else:
        # Exclude all onnx tests otherwise
        options.exclude.extend(onnx_tests)

    # process exclusion
    if options.exclude_jit_executor:
        options.exclude.extend(JIT_EXECUTOR_TESTS)

    if options.exclude_distributed_tests:
        options.exclude.extend(DISTRIBUTED_TESTS)

    if options.exclude_inductor_tests:
        options.exclude.extend(INDUCTOR_TESTS)

    if options.exclude_torch_export_tests:
        options.exclude.extend(TORCH_EXPORT_TESTS)

    # these tests failing in CUDA 11.6 temporary disabling. issue https://github.com/pytorch/pytorch/issues/75375
    # (the check below applies to every CUDA build, not only 11.6)
    if torch.version.cuda is not None:
        options.exclude.extend(["distributions/test_constraints"])

    # these tests failing in Python 3.12 temporarily disabling
    if sys.version_info >= (3, 12):
        options.exclude.extend(
            [
                "functorch/test_dims",
                "functorch/test_rearrange",
                "functorch/test_parsing",
                "functorch/test_memory_efficient_fusion",
                "torch_np/numpy_tests/core/test_multiarray",
            ]
        )

    selected_tests = exclude_tests(options.exclude, selected_tests)

    if sys.platform == "win32" and not options.ignore_win_blocklist:
        target_arch = os.environ.get("VSCMD_ARG_TGT_ARCH")
        if target_arch != "x64":
            # NOTE: these appends modify the module-level WINDOWS_BLOCKLIST.
            WINDOWS_BLOCKLIST.append("cpp_extensions_aot_no_ninja")
            WINDOWS_BLOCKLIST.append("cpp_extensions_aot_ninja")
            WINDOWS_BLOCKLIST.append("cpp_extensions_jit")
            WINDOWS_BLOCKLIST.append("jit")
            WINDOWS_BLOCKLIST.append("jit_fuser")

        selected_tests = exclude_tests(WINDOWS_BLOCKLIST, selected_tests, "on Windows")

    elif TEST_WITH_ROCM:
        selected_tests = exclude_tests(ROCM_BLOCKLIST, selected_tests, "on ROCm")

    # skip all distributed tests if distributed package is not available.
    if not dist.is_available():
        selected_tests = exclude_tests(
            DISTRIBUTED_TESTS,
            selected_tests,
            "PyTorch is built without distributed support.",
        )

    # skip tests that require LAPACK when it's not available
    if not torch._C.has_lapack:
        selected_tests = exclude_tests(
            TESTS_REQUIRING_LAPACK,
            selected_tests,
            "PyTorch is built without LAPACK support.",
        )

    if TEST_WITH_SLOW_GRADCHECK:
        selected_tests = exclude_tests(
            TESTS_NOT_USING_GRADCHECK,
            selected_tests,
            "Running in slow gradcheck mode, skipping tests "
            "that don't use gradcheck.",
            exact_match=True,
        )

    selected_tests = [parse_test_module(x) for x in selected_tests]
    return selected_tests
1511*da0073e9SAndroid Build Coastguard Worker
1512*da0073e9SAndroid Build Coastguard Worker
def load_test_times_from_file(file: str) -> Dict[str, Any]:
    """Load previously recorded test times used to make sharding decisions.

    ``file`` is joined onto REPO_ROOT and parsed as JSON. The document is
    indexed by build environment and then by test config; the lookup uses the
    BUILD_ENVIRONMENT and TEST_CONFIG environment variables and falls back in
    this order:

    1. ``[BUILD_ENVIRONMENT][TEST_CONFIG]``
    2. ``["default"][TEST_CONFIG]``
    3. ``["default"]["default"]``

    Args:
        file: path of the test times JSON file, relative to REPO_ROOT.

    Returns:
        The selected test times mapping, or an empty dict when the file does
        not exist or contains none of the entries above.
    """
    path = os.path.join(str(REPO_ROOT), file)
    if not os.path.exists(path):
        print_to_stderr(
            f"::warning:: Failed to find test times file `{path}`. Using round robin sharding."
        )
        return {}

    with open(path) as f:
        test_times_file = cast(Dict[str, Any], json.load(f))
    build_environment = os.environ.get("BUILD_ENVIRONMENT")
    test_config = os.environ.get("TEST_CONFIG")

    if test_config in test_times_file.get(build_environment, {}):
        print_to_stderr("Found test times from artifacts")
        return test_times_file[build_environment][test_config]

    # Tolerate a file without a "default" section instead of raising KeyError.
    default_env_times = test_times_file.get("default", {})
    if test_config in default_env_times:
        print_to_stderr(
            f"::warning:: Gathered no stats from artifacts for {build_environment} build env"
            f" and {test_config} test config. Using default build env and {test_config} test config instead."
        )
        return default_env_times[test_config]

    if "default" in default_env_times:
        print_to_stderr(
            f"::warning:: Gathered no stats from artifacts for {build_environment} build env"
            f" and {test_config} test config. Using default build env and default test config instead."
        )
        return default_env_times["default"]

    # Nothing usable in the file: behave as if it were missing.
    print_to_stderr(
        f"::warning:: Test times file `{path}` has no default entry. Using round robin sharding."
    )
    return {}
1541*da0073e9SAndroid Build Coastguard Worker
1542*da0073e9SAndroid Build Coastguard Worker
def load_test_file_times(
    file: str = ADDITIONAL_CI_FILES_FOLDER / TEST_TIMES_FILE,
) -> Dict[str, float]:
    """Load test times from ``file`` through load_test_times_from_file().

    The loaded mapping is only re-typed with ``cast`` as ``Dict[str, float]``;
    no runtime conversion or validation is performed here.
    """
    file_times = load_test_times_from_file(file)
    return cast(Dict[str, float], file_times)
1547*da0073e9SAndroid Build Coastguard Worker
1548*da0073e9SAndroid Build Coastguard Worker
def load_test_class_times(
    file: str = ADDITIONAL_CI_FILES_FOLDER / TEST_CLASS_TIMES_FILE,
) -> Dict[str, Dict[str, float]]:
    """Load test times from ``file`` through load_test_times_from_file().

    The loaded mapping is only re-typed with ``cast`` as
    ``Dict[str, Dict[str, float]]``; no runtime conversion or validation is
    performed here.
    """
    class_times = load_test_times_from_file(file)
    return cast(Dict[str, Dict[str, float]], class_times)
1553*da0073e9SAndroid Build Coastguard Worker
1554*da0073e9SAndroid Build Coastguard Worker
def get_sharding_opts(options) -> Tuple[int, int]:
    """Extract ``(which_shard, num_shards)`` from ``options.shard``.

    Args:
        options: object with a ``shard`` attribute that is either falsy (no
            sharding requested) or a pair ``[which_shard, num_shards]``.

    Returns:
        ``(1, 1)`` when no shard was requested, otherwise the requested pair.

    Raises:
        AssertionError: if the pair does not have exactly two entries, if an
            entry is not positive, or if ``which_shard`` exceeds
            ``num_shards``.
    """
    shard_spec = options.shard
    if not shard_spec:
        # No sharding requested: a single shard that contains everything.
        return (1, 1)

    assert len(shard_spec) == 2, "Unexpected shard format"
    assert min(shard_spec) > 0, "Shards must be positive numbers"
    which_shard, num_shards = shard_spec
    assert (
        which_shard <= num_shards
    ), "Selected shard must be less than or equal to total number of shards"
    return (which_shard, num_shards)
1566*da0073e9SAndroid Build Coastguard Worker
1567*da0073e9SAndroid Build Coastguard Worker
def do_sharding(
    options,
    selected_tests: Sequence[TestRun],
    test_file_times: Dict[str, float],
    test_class_times: Dict[str, Dict[str, float]],
    sort_by_time: bool = True,
) -> Tuple[float, List[ShardedTest]]:
    """Split the selected tests into shards and return the requested one.

    Args:
        options: parsed options; ``options.shard`` selects the shard via
            get_sharding_opts().
        selected_tests: the tests to distribute over the shards.
        test_file_times: forwarded to calculate_shards().
        test_class_times: forwarded to calculate_shards().
        sort_by_time: forwarded to calculate_shards().

    Returns:
        The entry of calculate_shards()'s result for the selected shard.
    """
    which_shard, num_shards = get_sharding_opts(options)

    all_shards = calculate_shards(
        num_shards,
        selected_tests,
        test_file_times,
        test_class_times=test_class_times,
        must_serial=must_serial,
        sort_by_time=sort_by_time,
    )

    # Shard numbers on the command line are 1-based; list indices are 0-based.
    shard_index = which_shard - 1
    return all_shards[shard_index]
1587*da0073e9SAndroid Build Coastguard Worker
1588*da0073e9SAndroid Build Coastguard Worker
class TestFailure(NamedTuple):
    """Immutable record pairing a test run with a description of its failure.

    This is the non-None result type declared by run_test_module().
    """

    # The test run the failure refers to.
    test: TestRun
    # Failure description; presumably the "<test> failed!" text assembled in
    # run_test_module -- the construction site is outside this view, confirm.
    message: str
1592*da0073e9SAndroid Build Coastguard Worker
1593*da0073e9SAndroid Build Coastguard Worker
def run_test_module(
    test: ShardedTest, test_directory: str, options
) -> Optional[TestFailure]:
    """Run a single sharded test and translate its outcome into a result.

    The test is dispatched to the handler registered for its name in
    ``CUSTOM_HANDLERS``, falling back to ``run_test``.

    Args:
        test: The sharded test to execute.
        test_directory: Directory passed through to the handler.
        options: Parsed command line options passed through to the handler.

    Returns:
        ``None`` when the handler returns exit code 0, otherwise a
        ``TestFailure`` describing what went wrong.  Any ``Exception`` raised
        while running (including a non-integer return code) is converted into
        a ``TestFailure`` instead of propagating.
    """
    try:
        maybe_set_hip_visible_devies()

        test_name = test.name

        # Printing the date here can help diagnose which tests are slow
        print_to_stderr(f"Running {str(test)} ... [{datetime.now()}]")
        handler = CUSTOM_HANDLERS.get(test_name, run_test)
        return_code = handler(test, test_directory, options)

        # Explicit check rather than ``assert`` so the validation still runs
        # under ``python -O``.  ``bool`` is rejected even though it subclasses
        # ``int``.  The error is caught below and reported as a failure.
        if isinstance(return_code, bool) or not isinstance(return_code, int):
            raise TypeError(
                f"While running {str(test)} got non integer return code {return_code}"
            )
        if return_code == 0:
            return None

        message = f"{str(test)} failed!"
        if return_code < 0:
            # subprocess.Popen returns the child process' exit signal as
            # return code -N, where N is the signal number.
            signal_number = -return_code
            # Fall back to a placeholder for signal numbers that are missing
            # from the table, so the report still says the test died from a
            # signal instead of surfacing a bare KeyError.
            signal_name = SIGNALS_TO_NAMES_DICT.get(
                signal_number, f"<unknown signal {signal_number}>"
            )
            message += f" Received signal: {signal_name}"
        return TestFailure(test.test, message)
    except Exception as e:
        return TestFailure(test.test, f"{str(test)} failed! {e}")
1621*da0073e9SAndroid Build Coastguard Worker
1622*da0073e9SAndroid Build Coastguard Worker
def run_tests(
    selected_tests: List[ShardedTest],
    test_directory: str,
    options,
    failures: List[TestFailure],
) -> None:
    """Run ``selected_tests`` and append every failure to ``failures``.

    The work happens in three phases, in this order:

    1. Tests for which ``must_serial`` is true are run one at a time.
    2. The remaining tests are run one at a time with ``-m serial`` appended
       to their additional arguments.
    3. The same remaining tests are submitted to a spawn-based process pool
       of ``NUM_PROCS`` workers with ``-m "not serial"`` appended.

    Unless ``options.continue_through_error`` is set or
    ``RERUN_DISABLED_TESTS`` is true, the first failure in phases 1 and 2
    raises, and a failure in phase 3 terminates the pool.

    Args:
        selected_tests: Sharded tests to run.
        test_directory: Directory holding the tests (also the source of the
            conftest files copied for C++ tests).
        options: Parsed command line options; a deep copy is handed to every
            individual test run.
        failures: Output list that receives one ``TestFailure`` per failure.

    Raises:
        RuntimeError: A test failed in phase 1 or 2 and the run is not
            configured to continue through errors.
    """
    if len(selected_tests) == 0:
        return

    # parallel = in parallel with other files
    # serial = this file on its own.  The file might still be run in parallel with itself (ex test_ops)
    # Partition in a single pass instead of re-scanning the parallel list for
    # every test.
    selected_tests_parallel: List[ShardedTest] = []
    selected_tests_serial: List[ShardedTest] = []
    for candidate in selected_tests:
        if must_serial(candidate):
            selected_tests_serial.append(candidate)
        else:
            selected_tests_parallel.append(candidate)

    # NB: This is a hack to make conftest.py and files it depends on available
    # on CPP_TESTS_DIR. We should see if the file could be turned into a
    # full-fledged pytest plugin instead
    if options.cpp and os.path.isdir(CPP_TESTS_DIR):
        for conftest_file in ("conftest.py", "pytest_shard_custom.py"):
            cpp_file = os.path.join(CPP_TESTS_DIR, conftest_file)
            if not os.path.exists(cpp_file):
                shutil.copy(os.path.join(test_directory, conftest_file), cpp_file)

    keep_going_message = (
        "\n\nTip: You can keep running tests even on failure by passing --keep-going to run_test.py.\n"
        "If running on CI, add the 'keep-going' label to your PR and rerun your jobs."
    )
    stop_on_failure = not options.continue_through_error and not RERUN_DISABLED_TESTS

    def handle_error_messages(failure: Optional[TestFailure]) -> bool:
        # Record and print a failure; returns True when there was one.
        if failure is None:
            return False
        failures.append(failure)
        print_to_stderr(failure.message)
        return True

    def parallel_test_completion_callback(failure: Optional[TestFailure]) -> None:
        # Invoked by the pool with the result of ``run_test_module``.
        if handle_error_messages(failure) and stop_on_failure:
            pool.terminate()

    def make_error_callback(test: ShardedTest):
        # Without an error callback a task that raises inside the pool
        # machinery is dropped silently and the test would count as passing.
        # Report it as a regular failure instead.
        def on_error(exc: BaseException) -> None:
            parallel_test_completion_callback(
                TestFailure(test.test, f"{str(test)} failed! {exc}")
            )

        return on_error

    def options_for(test: ShardedTest, marker: Optional[str] = None):
        # Every run gets its own copy so per-test tweaks never leak into
        # ``options`` or into other runs.
        options_clone = copy.deepcopy(options)
        if can_run_in_pytest(test):
            options_clone.pytest = True
        if marker is not None:
            options_clone.additional_args.extend(["-m", marker])
        return options_clone

    def run_one_at_a_time(test: ShardedTest, marker: Optional[str] = None) -> None:
        # Run one test without the pool and abort the whole run on failure
        # unless configured to keep going.
        failure = run_test_module(test, test_directory, options_for(test, marker))
        if handle_error_messages(failure) and stop_on_failure:
            raise RuntimeError(failure.message + keep_going_message)

    # See Note [ROCm parallel CI testing]
    # The pool is created right before the ``try`` so that nothing can raise
    # between spawning the workers and the cleanup in ``finally``.
    pool = get_context("spawn").Pool(
        NUM_PROCS, maxtasksperchild=None if torch.version.hip else 1
    )

    try:
        for test in selected_tests_serial:
            run_one_at_a_time(test)

        # Run tests marked as serial first
        for test in selected_tests_parallel:
            run_one_at_a_time(test, "serial")

        os.environ["NUM_PARALLEL_PROCS"] = str(NUM_PROCS)
        try:
            for test in selected_tests_parallel:
                pool.apply_async(
                    run_test_module,
                    args=(test, test_directory, options_for(test, "not serial")),
                    callback=parallel_test_completion_callback,
                    error_callback=make_error_callback(test),
                )
            pool.close()
            pool.join()
        finally:
            # Remove the variable even when the parallel phase raises.
            os.environ.pop("NUM_PARALLEL_PROCS", None)

    finally:
        pool.terminate()
        pool.join()

    return
1731*da0073e9SAndroid Build Coastguard Worker
1732*da0073e9SAndroid Build Coastguard Worker
def check_pip_packages() -> None:
    """Exit with status 1 if a required pip package is not installed.

    The required names are compared against the ``key`` of every
    distribution in ``pkg_resources.working_set``.  The first missing package
    is reported on stderr before exiting.
    """
    required = (
        "pytest-rerunfailures",
        "pytest-flakefinder",
        "pytest-xdist",
    )
    installed = {distribution.key for distribution in pkg_resources.working_set}

    for package in required:
        if package in installed:
            continue
        print_to_stderr(
            f"Missing pip dependency: {package}, please run `pip install -r .ci/docker/requirements-ci.txt`"
        )
        sys.exit(1)
1746*da0073e9SAndroid Build Coastguard Worker
1747*da0073e9SAndroid Build Coastguard Worker
def main():
    """Select, shard and run the tests, then report the failures.

    Steps: verify the pip dependencies, parse the options, register the shard
    info as global metrics, pick the tests to include/exclude, print the
    plan, run the included tests for this shard and finally emit
    coverage/metrics.  Exits with status 1 when any test failed, unless
    ``RERUN_DISABLED_TESTS`` is set.
    """
    check_pip_packages()

    options = parse_args()

    # Include sharding info in all metrics
    shard_number, shard_count = get_sharding_opts(options)
    add_global_metric("shard", shard_number)
    add_global_metric("num_shards", shard_count)

    test_directory = str(REPO_ROOT / "test")
    selected_tests = get_selected_tests(options)

    test_prioritizations = import_results()
    test_prioritizations.amend_tests(selected_tests)

    reports_dir = REPO_ROOT / "test" / "test-reports"
    os.makedirs(reports_dir, exist_ok=True)

    if options.coverage and not PYTORCH_COLLECT_COVERAGE:
        shell(["coverage", "erase"])

    if IS_CI:
        # downloading test cases configuration to local environment
        get_test_case_configs(dirpath=test_directory)

    test_file_times_dict = load_test_file_times()
    test_class_times_dict = load_test_class_times()

    class TestBatch:
        """Defines a set of tests with similar priority that should be run together on the current shard"""

        name: str
        time: float
        sharded_tests: List[ShardedTest]
        failures: List[TestFailure]

        def __init__(
            self, name: str, raw_tests: Sequence[TestRun], should_sort_shard: bool
        ):
            estimated_time, tests_for_shard = do_sharding(
                options,
                raw_tests,
                test_file_times_dict,
                test_class_times_dict,
                sort_by_time=should_sort_shard,
            )
            self.name = name
            self.failures = []
            self.time = estimated_time
            self.sharded_tests = tests_for_shard

        def __str__(self):
            serial = [t for t in self.sharded_tests if must_serial(t)]
            parallel = [t for t in self.sharded_tests if not must_serial(t)]

            lines = [f"Name: {self.name} (est. time: {round(self.time / 60, 2)}min)"]
            lines.append(f"  Serial tests ({len(serial)}):")
            lines.extend(f"    {t}" for t in serial)
            lines.append(f"  Parallel tests ({len(parallel)}):")
            lines.extend(f"    {t}" for t in parallel)
            return "\n".join(lines).strip()

    if options.enable_td:
        percent_to_run = 25
        print_to_stderr(f"Running {percent_to_run}% of tests based on TD")
    else:
        percent_to_run = 100
        print_to_stderr("Running all tests")
    include, exclude = test_prioritizations.get_top_per_tests(percent_to_run)

    test_batch = TestBatch("tests to run", include, False)
    test_batch_exclude = TestBatch("excluded", exclude, True)
    if IS_CI:
        included_json = [x.to_json() for x in include]
        excluded_json = [x.to_json() for x in exclude]
        gen_ci_artifact(included_json, excluded_json)

    print_to_stderr(f"Running parallel tests on {NUM_PROCS} processes")
    print_to_stderr(test_batch)
    print_to_stderr(test_batch_exclude)

    if options.dry_run:
        return

    if options.dynamo:
        os.environ["PYTORCH_TEST_WITH_DYNAMO"] = "1"
    elif options.inductor:
        os.environ["PYTORCH_TEST_WITH_INDUCTOR"] = "1"

    if not options.no_translation_validation:
        os.environ["PYTORCH_TEST_WITH_TV"] = "1"

    try:
        # Actually run the tests
        started_at = time.time()
        run_tests(
            test_batch.sharded_tests, test_directory, options, test_batch.failures
        )
        elapsed_time = time.time() - started_at
        print_to_stderr(
            f"Running test batch '{test_batch.name}' cost {round(elapsed_time, 2)} seconds"
        )

    finally:
        # Coverage and failure metrics are produced even when the run raised.
        if options.coverage:
            from coverage import Coverage

            with set_cwd(test_directory):
                cov = Coverage()
                if PYTORCH_COLLECT_COVERAGE:
                    cov.load()
                cov.combine(strict=False)
                cov.save()
                if not PYTORCH_COLLECT_COVERAGE:
                    cov.html_report()

        all_failures = test_batch.failures

        if IS_CI:
            for failure in all_failures:
                test_stats = test_prioritizations.get_test_stats(failure.test)
                print_to_stderr("Emiting td_test_failure_stats_v2")
                payload = {
                    "selected_tests": selected_tests,
                    "failure": str(failure.test),
                }
                payload.update(test_stats)
                emit_metric("td_test_failure_stats_v2", payload)
            gen_additional_test_failures_file(
                [failure.test.test_file for failure in all_failures]
            )

    if not all_failures:
        return

    for failure in all_failures:
        print_to_stderr(failure.message)

    # A disabled test is expected to fail, so there is no need to report a failure here
    if not RERUN_DISABLED_TESTS:
        sys.exit(1)
1884*da0073e9SAndroid Build Coastguard Worker
1885*da0073e9SAndroid Build Coastguard Worker
# Only start the test driver when this file is executed as a script, not
# when it is imported as a module.
if __name__ == "__main__":
    main()
1888