#!/usr/bin/env python3

import argparse
import copy
import glob
import json
import os
import re
import shutil
import signal
import subprocess
import sys
import tempfile
import time
from collections import defaultdict
from contextlib import ExitStack
from datetime import datetime
from pathlib import Path
from typing import Any, cast, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union

import pkg_resources

import torch
import torch.distributed as dist
from torch.multiprocessing import current_process, get_context
from torch.testing._internal.common_utils import (
    get_report_path,
    IS_CI,
    IS_MACOS,
    IS_WINDOWS,
    retry_shell,
    set_cwd,
    shell,
    TEST_CUDA,
    TEST_WITH_ASAN,
    TEST_WITH_CROSSREF,
    TEST_WITH_ROCM,
    TEST_WITH_SLOW_GRADCHECK,
)


# using tools/ to optimize test run.
# REPO_ROOT is two levels above this file's resolved path (the parent of the
# directory that contains this script). It is put at the front of sys.path
# only for the duration of the `tools.*` imports below.
REPO_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(REPO_ROOT))

from tools.stats.import_test_stats import (
    ADDITIONAL_CI_FILES_FOLDER,
    TEST_CLASS_TIMES_FILE,
    TEST_TIMES_FILE,
)
from tools.stats.upload_metrics import add_global_metric, emit_metric
from tools.testing.discover_tests import (
    CPP_TEST_PATH,
    CPP_TEST_PREFIX,
    CPP_TESTS_DIR,
    parse_test_module,
    TESTS,
)
from tools.testing.do_target_determination_for_s3 import import_results
from tools.testing.target_determination.gen_artifact import gen_ci_artifact
from tools.testing.target_determination.heuristics.previously_failed_in_pr import (
    gen_additional_test_failures_file,
)
from tools.testing.target_determination.heuristics.utils import get_pr_number
from tools.testing.test_run import TestRun
from tools.testing.test_selections import (
    calculate_shards,
    get_test_case_configs,
    NUM_PROCS,
    ShardedTest,
    THRESHOLD,
)


# Make sure to remove REPO_ROOT after import is done
# (list.remove drops the first matching entry, i.e. the one inserted at index 0
# above; presumably so later imports cannot resolve against the repo root).
sys.path.remove(str(REPO_ROOT))


# Set unconditionally: the tools/ imports above are not guarded by try/except,
# so reaching this line means they succeeded.
HAVE_TEST_SELECTION_TOOLS = True
TEST_CONFIG = os.getenv("TEST_CONFIG", "")
BUILD_ENVIRONMENT = os.getenv("BUILD_ENVIRONMENT", "")
# Enabled only when the variable is exactly "1".
RERUN_DISABLED_TESTS = os.getenv("PYTORCH_TEST_RERUN_DISABLED_TESTS", "0") == "1"
# Test-module name prefixes used to classify tests by directory.
DISTRIBUTED_TEST_PREFIX = "distributed"
INDUCTOR_TEST_PREFIX = "inductor"
# Substring match: any TEST_CONFIG or BUILD_ENVIRONMENT containing "slow" counts.
IS_SLOW = "slow" in TEST_CONFIG or "slow" in BUILD_ENVIRONMENT
# Note [ROCm parallel CI testing]
# https://github.com/pytorch/pytorch/pull/85770 added file-granularity parallel testing.
# In .ci/pytorch/test.sh, TEST_CONFIG == "default", CUDA and HIP_VISIBLE_DEVICES is set to 0.
# This results in multiple test files sharing the same GPU.
# This should be a supported use case for ROCm, but it exposed issues in the kernel driver resulting in hangs.
# See https://github.com/pytorch/pytorch/issues/90940.
#
# Further, ROCm self-hosted runners have up to 4 GPUs.
# Device visibility was set to 0 to match CUDA test behavior, but this was wasting available GPU resources.
# Assigning each Pool worker their own dedicated GPU avoids the ROCm oversubscription issues.
# This should also result in better overall wall clock time since all GPUs can be utilized.
def maybe_set_hip_visible_devies():
    """Restrict a parallel-pool worker to one HIP device on ROCm builds.

    Does nothing unless ``torch.version.hip`` is set AND the current process
    is not the ``MainProcess``. In that case ``HIP_VISIBLE_DEVICES`` is set in
    this process's environment to ``p._identity[0] % NUM_PROCS``.

    NOTE(review): ``_identity`` is a private multiprocessing attribute;
    presumably the worker's index within the pool — confirm if this breaks.
    """
    # Special handling of ROCm GHA runners for parallel (file granularity) tests.
    if torch.version.hip:
        p = current_process()
        if p.name != "MainProcess":
            # this is a Process from a parallel Pool, not the MainProcess
            os.environ["HIP_VISIBLE_DEVICES"] = str(p._identity[0] % NUM_PROCS)


def strtobool(s):
    """Interpret a string flag as a boolean.

    Returns False only for "", "0", "false" or "off" (case-insensitive);
    every other string -- including "no" or "n" -- is treated as True.
    """
    return s.lower() not in {"", "0", "false", "off"}


class TestChoices(list):
    """A list of test names whose membership test normalizes the query.

    ``item in choices`` checks ``parse_test_module(item)`` rather than
    ``item`` itself (presumably so argparse ``choices=`` accepts test names
    carrying extra qualifiers). Only the first positional constructor
    argument is used; any further args/kwargs are ignored.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(args[0])

    def __contains__(self, item):
        return list.__contains__(self, parse_test_module(item))


FSDP_TEST = [test for test in TESTS if test.startswith("distributed/fsdp")]

# NOTE: every entry in these lists needs a trailing comma. Adjacent string
# literals are silently concatenated by Python, which turns a missing comma
# into one bogus entry and un-blocks the neighbouring tests.
WINDOWS_BLOCKLIST = [
    "distributed/nn/jit/test_instantiator",
    "distributed/rpc/test_faulty_agent",
    "distributed/rpc/test_tensorpipe_agent",
    "distributed/rpc/test_share_memory",
    "distributed/rpc/cuda/test_tensorpipe_agent",
    "distributed/pipeline/sync/skip/test_api",
    "distributed/pipeline/sync/skip/test_gpipe",
    "distributed/pipeline/sync/skip/test_inspect_skip_layout",
    "distributed/pipeline/sync/skip/test_leak",
    "distributed/pipeline/sync/skip/test_portal",
    "distributed/pipeline/sync/skip/test_stash_pop",
    "distributed/pipeline/sync/skip/test_tracker",
    "distributed/pipeline/sync/skip/test_verify_skippables",
    "distributed/pipeline/sync/test_balance",
    "distributed/pipeline/sync/test_bugs",
    "distributed/pipeline/sync/test_checkpoint",
    "distributed/pipeline/sync/test_copy",
    "distributed/pipeline/sync/test_deferred_batch_norm",
    "distributed/pipeline/sync/test_dependency",
    "distributed/pipeline/sync/test_inplace",
    "distributed/pipeline/sync/test_microbatch",
    "distributed/pipeline/sync/test_phony",
    "distributed/pipeline/sync/test_pipe",
    "distributed/pipeline/sync/test_pipeline",
    "distributed/pipeline/sync/test_stream",
    "distributed/pipeline/sync/test_transparency",
    "distributed/pipeline/sync/test_worker",
    "distributed/elastic/agent/server/test/api_test",
    "distributed/elastic/multiprocessing/api_test",
    "distributed/_shard/checkpoint/test_checkpoint",
    "distributed/_shard/checkpoint/test_file_system_checkpoint",
    "distributed/_shard/sharding_spec/test_sharding_spec",
    "distributed/_shard/sharding_plan/test_sharding_plan",
    "distributed/_shard/sharded_tensor/test_sharded_tensor",
    "distributed/_shard/sharded_tensor/test_sharded_tensor_reshard",
    "distributed/_shard/sharded_tensor/ops/test_embedding",
    "distributed/_shard/sharded_tensor/ops/test_embedding_bag",
    "distributed/_shard/sharded_tensor/ops/test_binary_cmp",
    "distributed/_shard/sharded_tensor/ops/test_init",
    "distributed/_shard/sharded_optim/test_sharded_optim",
] + FSDP_TEST

ROCM_BLOCKLIST = [
    "distributed/rpc/test_faulty_agent",
    "distributed/rpc/test_tensorpipe_agent",
    "distributed/rpc/test_share_memory",
    "distributed/rpc/cuda/test_tensorpipe_agent",
    "distributed/_shard/checkpoint/test_checkpoint",
    "distributed/_shard/checkpoint/test_file_system_checkpoint",
    "distributed/_shard/sharding_spec/test_sharding_spec",
    "distributed/_shard/sharding_plan/test_sharding_plan",
    "distributed/_shard/sharded_tensor/test_sharded_tensor",
    "distributed/_shard/sharded_tensor/test_sharded_tensor_reshard",
    "distributed/_shard/sharded_tensor/ops/test_embedding",
    "distributed/_shard/sharded_tensor/ops/test_embedding_bag",
    "distributed/_shard/sharded_tensor/ops/test_binary_cmp",
    "distributed/_shard/sharded_tensor/ops/test_init",
    "distributed/_shard/sharded_optim/test_sharded_optim",
    "test_determination",
    "test_jit_legacy",
    "test_cuda_nvml_based_avail",
    "test_jit_cuda_fuser",
    "distributed/_tensor/test_attention",
]

XPU_BLOCKLIST = [
    "test_autograd",
    "profiler/test_cpp_thread",
    "profiler/test_execution_trace",
    "profiler/test_memory_profiler",
    "profiler/test_profiler",
    "profiler/test_profiler_tree",
    "profiler/test_record_function",
    "profiler/test_torch_tidy",
]

XPU_TEST = [
    "test_xpu",
]

# The tests inside these files should never be run in parallel with each other
RUN_PARALLEL_BLOCKLIST = [
    "test_cpp_extensions_jit",
    "test_cpp_extensions_open_device_registration",
    "test_cpp_extensions_stream_and_event",
    "test_cpp_extensions_mtia_backend",
    "test_jit_disabled",
    "test_mobile_optimizer",
    "test_multiprocessing",
    "test_multiprocessing_spawn",
    "test_namedtuple_return_api",
    "test_overrides",
    "test_show_pickle",
    "test_tensorexpr",
    "test_cuda_primary_ctx",
    "test_cuda_trace",
    "inductor/test_benchmark_fusion",
    "test_cuda_nvml_based_avail",
    # temporarily sets a global config
    "test_autograd_fallback",
] + FSDP_TEST
# Test files that should always be run serially with other test files,
# but it's okay if the tests inside them are run in parallel with each other.
CI_SERIAL_LIST = [
    "test_nn",
    "test_fake_tensor",
    "test_cpp_api_parity",
    "test_reductions",
    "test_fx_backends",
    "test_cpp_extensions_jit",
    "test_torch",
    "test_tensor_creation_ops",
    "test_dispatch",
    # torch.library creation and deletion must be serialized
    "test_python_dispatch",
    # Cause CUDA illegal memory access https://github.com/pytorch/pytorch/issues/88916
    "test_spectral_ops",
    "nn/test_pooling",
    # Doesn't respect set_per_process_memory_fraction, results in OOM for
    # other tests in slow gradcheck
    "nn/test_convolution",
    "distributions/test_distributions",
    # gets SIGKILL
    "test_fx",
    # Cause CUDA OOM on ROCm
    "functorch/test_memory_efficient_fusion",
    # The next six files are each annotated as running out of memory (OOM).
    "test_utils",
    "test_sort_and_select",
    "test_backward_compatible_arguments",
    "test_autocast",
    "test_native_mha",
    "test_module_hooks",
    "inductor/test_max_autotune",
    # slow due to many nvcc compilation steps
    "inductor/test_cutlass_backend",
    # OOM
    "inductor/test_flex_attention",
]

# A subset of onnx tests that cannot run in parallel due to high memory usage.
ONNX_SERIAL_LIST = [
    "onnx/test_models",
    "onnx/test_models_quantized_onnxruntime",
    "onnx/test_models_onnxruntime",
    "onnx/test_custom_ops",
    "onnx/test_utility_funs",
]

# A subset of our TEST list that validates PyTorch's ops, modules, and
# autograd function as expected.
CORE_TEST_LIST = [
    "test_autograd",
    "test_autograd_fallback",
    "test_modules",
    "test_nn",
    "test_ops",
    "test_ops_gradients",
    "test_ops_fwd_gradients",
    "test_ops_jit",
    "test_torch",
]


# if a test file takes longer than 5 min, we add it to TARGET_DET_LIST
# Per-file runtime threshold in seconds (300 s == 5 minutes).
SLOW_TEST_THRESHOLD = 300


def _default_world_size():
    """Return the WORLD_SIZE string for the nccl/gloo/ucc backends.

    "2" when exactly two CUDA devices are reported, otherwise "3".
    """
    return "2" if torch.cuda.device_count() == 2 else "3"


def _build_distributed_tests_config():
    """Collect per-backend environment overrides for distributed tests.

    Returns an empty dict when torch.distributed is unavailable. Otherwise
    the "test" entry is always present, "mpi" is added when MPI is available
    and this is not a ROCm run, and "nccl"/"gloo"/"ucc" are added when the
    corresponding backend is available.
    """
    config = {}
    if not dist.is_available():
        return config

    config["test"] = {"WORLD_SIZE": "1"}
    if not TEST_WITH_ROCM and dist.is_mpi_available():
        config["mpi"] = {
            "WORLD_SIZE": "3",
            "TEST_REPORT_SOURCE_OVERRIDE": "dist-mpi",
        }
    if dist.is_nccl_available():
        config["nccl"] = {
            "WORLD_SIZE": _default_world_size(),
            "TEST_REPORT_SOURCE_OVERRIDE": "dist-nccl",
        }
    if dist.is_gloo_available():
        config["gloo"] = {
            "WORLD_SIZE": _default_world_size(),
            "TEST_REPORT_SOURCE_OVERRIDE": "dist-gloo",
        }
    if dist.is_ucc_available():
        config["ucc"] = {
            "WORLD_SIZE": _default_world_size(),
            "TEST_REPORT_SOURCE_OVERRIDE": "dist-ucc",
            "UCX_TLS": "tcp,cuda",
            "UCC_TLS": "nccl,ucp,cuda",
            # don't use UCP TL on CUDA as it is not well supported
            "UCC_TL_UCP_TUNE": "cuda:0",
            # CI nodes (M60) fail if it is on
            "UCC_EC_CUDA_USE_COOPERATIVE_LAUNCH": "n",
        }
    return config


DISTRIBUTED_TESTS_CONFIG = _build_distributed_tests_config()

# https://stackoverflow.com/questions/2549939/get-signal-names-from-numbers-in-python
# Maps signal numbers to "SIG*" names. Names containing "_" (e.g. SIG_DFL,
# SIG_IGN) are skipped; when two names share a number, the later one in
# dir() order wins.
SIGNALS_TO_NAMES_DICT = {
    getattr(signal, sig_name): sig_name
    for sig_name in dir(signal)
    if sig_name.startswith("SIG") and "_" not in sig_name
}

CPP_EXTENSIONS_ERROR = """
Ninja (https://ninja-build.org) is required for some of the C++ extensions
tests, but it could not be found. Install ninja with `pip install ninja`
or `conda install ninja`. Alternatively, disable said tests with
`run_test.py --exclude test_cpp_extensions_aot_ninja test_cpp_extensions_jit`.
"""

# NOTE(review): any non-empty value -- including "0" -- turns this on.
PYTORCH_COLLECT_COVERAGE = os.getenv("PYTORCH_COLLECT_COVERAGE", "") != ""

JIT_EXECUTOR_TESTS = [
    "test_jit_profiling",
    "test_jit_legacy",
    "test_jit_fuser_legacy",
]


def _tests_with_prefix(prefix):
    """Return the entries of TESTS that start with ``prefix``, in TESTS order."""
    return [name for name in TESTS if name.startswith(prefix)]


INDUCTOR_TESTS = _tests_with_prefix(INDUCTOR_TEST_PREFIX)
DISTRIBUTED_TESTS = _tests_with_prefix(DISTRIBUTED_TEST_PREFIX)
TORCH_EXPORT_TESTS = _tests_with_prefix("export")
FUNCTORCH_TESTS = _tests_with_prefix("functorch")
ONNX_TESTS = _tests_with_prefix("onnx")
CPP_TESTS = _tests_with_prefix(CPP_TEST_PREFIX)

TESTS_REQUIRING_LAPACK = [
    "distributions/test_constraints",
    "distributions/test_distributions",
]

# These are just the slowest ones, this isn't an exhaustive list.
343*da0073e9SAndroid Build Coastguard WorkerTESTS_NOT_USING_GRADCHECK = [ 344*da0073e9SAndroid Build Coastguard Worker # Note that you should use skipIfSlowGradcheckEnv if you do not wish to 345*da0073e9SAndroid Build Coastguard Worker # skip all the tests in that file, e.g. test_mps 346*da0073e9SAndroid Build Coastguard Worker "doctests", 347*da0073e9SAndroid Build Coastguard Worker "test_meta", 348*da0073e9SAndroid Build Coastguard Worker "test_hub", 349*da0073e9SAndroid Build Coastguard Worker "test_fx", 350*da0073e9SAndroid Build Coastguard Worker "test_decomp", 351*da0073e9SAndroid Build Coastguard Worker "test_cpp_extensions_jit", 352*da0073e9SAndroid Build Coastguard Worker "test_jit", 353*da0073e9SAndroid Build Coastguard Worker "test_ops", 354*da0073e9SAndroid Build Coastguard Worker "test_ops_jit", 355*da0073e9SAndroid Build Coastguard Worker "dynamo/test_recompile_ux", 356*da0073e9SAndroid Build Coastguard Worker "inductor/test_smoke", 357*da0073e9SAndroid Build Coastguard Worker "test_quantization", 358*da0073e9SAndroid Build Coastguard Worker] 359*da0073e9SAndroid Build Coastguard Worker 360*da0073e9SAndroid Build Coastguard Worker 361*da0073e9SAndroid Build Coastguard Workerdef print_to_stderr(message): 362*da0073e9SAndroid Build Coastguard Worker print(message, file=sys.stderr) 363*da0073e9SAndroid Build Coastguard Worker 364*da0073e9SAndroid Build Coastguard Worker 365*da0073e9SAndroid Build Coastguard Workerdef get_executable_command(options, disable_coverage=False, is_cpp_test=False): 366*da0073e9SAndroid Build Coastguard Worker if options.coverage and not disable_coverage: 367*da0073e9SAndroid Build Coastguard Worker if not is_cpp_test: 368*da0073e9SAndroid Build Coastguard Worker executable = ["coverage", "run", "--parallel-mode", "--source=torch"] 369*da0073e9SAndroid Build Coastguard Worker else: 370*da0073e9SAndroid Build Coastguard Worker # TODO: C++ with coverage is not yet supported 371*da0073e9SAndroid Build Coastguard Worker executable = 
[] 372*da0073e9SAndroid Build Coastguard Worker else: 373*da0073e9SAndroid Build Coastguard Worker if not is_cpp_test: 374*da0073e9SAndroid Build Coastguard Worker executable = [sys.executable, "-bb"] 375*da0073e9SAndroid Build Coastguard Worker else: 376*da0073e9SAndroid Build Coastguard Worker executable = ["pytest"] 377*da0073e9SAndroid Build Coastguard Worker 378*da0073e9SAndroid Build Coastguard Worker return executable 379*da0073e9SAndroid Build Coastguard Worker 380*da0073e9SAndroid Build Coastguard Worker 381*da0073e9SAndroid Build Coastguard Workerdef run_test( 382*da0073e9SAndroid Build Coastguard Worker test_module: ShardedTest, 383*da0073e9SAndroid Build Coastguard Worker test_directory, 384*da0073e9SAndroid Build Coastguard Worker options, 385*da0073e9SAndroid Build Coastguard Worker launcher_cmd=None, 386*da0073e9SAndroid Build Coastguard Worker extra_unittest_args=None, 387*da0073e9SAndroid Build Coastguard Worker env=None, 388*da0073e9SAndroid Build Coastguard Worker print_log=True, 389*da0073e9SAndroid Build Coastguard Worker) -> int: 390*da0073e9SAndroid Build Coastguard Worker scribe_token = os.getenv("SCRIBE_GRAPHQL_ACCESS_TOKEN", "") 391*da0073e9SAndroid Build Coastguard Worker if scribe_token: 392*da0073e9SAndroid Build Coastguard Worker print_to_stderr("SCRIBE_GRAPHQL_ACCESS_TOKEN is set") 393*da0073e9SAndroid Build Coastguard Worker else: 394*da0073e9SAndroid Build Coastguard Worker print_to_stderr("SCRIBE_GRAPHQL_ACCESS_TOKEN is NOT set") 395*da0073e9SAndroid Build Coastguard Worker 396*da0073e9SAndroid Build Coastguard Worker env = env or os.environ.copy() 397*da0073e9SAndroid Build Coastguard Worker maybe_set_hip_visible_devies() 398*da0073e9SAndroid Build Coastguard Worker unittest_args = options.additional_args.copy() 399*da0073e9SAndroid Build Coastguard Worker test_file = test_module.name 400*da0073e9SAndroid Build Coastguard Worker stepcurrent_key = test_file 401*da0073e9SAndroid Build Coastguard Worker 402*da0073e9SAndroid Build 
Coastguard Worker is_distributed_test = test_file.startswith(DISTRIBUTED_TEST_PREFIX) 403*da0073e9SAndroid Build Coastguard Worker is_cpp_test = test_file.startswith(CPP_TEST_PREFIX) 404*da0073e9SAndroid Build Coastguard Worker # NB: Rerun disabled tests depends on pytest-flakefinder and it doesn't work with 405*da0073e9SAndroid Build Coastguard Worker # pytest-cpp atm. We also don't have support to disable C++ test yet, so it's ok 406*da0073e9SAndroid Build Coastguard Worker # to just return successfully here 407*da0073e9SAndroid Build Coastguard Worker if is_cpp_test and RERUN_DISABLED_TESTS: 408*da0073e9SAndroid Build Coastguard Worker print_to_stderr( 409*da0073e9SAndroid Build Coastguard Worker "Skipping C++ tests when running under RERUN_DISABLED_TESTS mode" 410*da0073e9SAndroid Build Coastguard Worker ) 411*da0073e9SAndroid Build Coastguard Worker return 0 412*da0073e9SAndroid Build Coastguard Worker 413*da0073e9SAndroid Build Coastguard Worker if is_cpp_test: 414*da0073e9SAndroid Build Coastguard Worker stepcurrent_key = f"{test_file}_{os.urandom(8).hex()}" 415*da0073e9SAndroid Build Coastguard Worker else: 416*da0073e9SAndroid Build Coastguard Worker unittest_args.extend( 417*da0073e9SAndroid Build Coastguard Worker [ 418*da0073e9SAndroid Build Coastguard Worker f"--shard-id={test_module.shard}", 419*da0073e9SAndroid Build Coastguard Worker f"--num-shards={test_module.num_shards}", 420*da0073e9SAndroid Build Coastguard Worker ] 421*da0073e9SAndroid Build Coastguard Worker ) 422*da0073e9SAndroid Build Coastguard Worker stepcurrent_key = f"{test_file}_{test_module.shard}_{os.urandom(8).hex()}" 423*da0073e9SAndroid Build Coastguard Worker 424*da0073e9SAndroid Build Coastguard Worker if options.verbose: 425*da0073e9SAndroid Build Coastguard Worker unittest_args.append(f'-{"v" * options.verbose}') # in case of pytest 426*da0073e9SAndroid Build Coastguard Worker 427*da0073e9SAndroid Build Coastguard Worker if test_file in RUN_PARALLEL_BLOCKLIST: 
428*da0073e9SAndroid Build Coastguard Worker unittest_args = [ 429*da0073e9SAndroid Build Coastguard Worker arg for arg in unittest_args if not arg.startswith("--run-parallel") 430*da0073e9SAndroid Build Coastguard Worker ] 431*da0073e9SAndroid Build Coastguard Worker 432*da0073e9SAndroid Build Coastguard Worker if extra_unittest_args: 433*da0073e9SAndroid Build Coastguard Worker assert isinstance(extra_unittest_args, list) 434*da0073e9SAndroid Build Coastguard Worker unittest_args.extend(extra_unittest_args) 435*da0073e9SAndroid Build Coastguard Worker 436*da0073e9SAndroid Build Coastguard Worker # If using pytest, replace -f with equivalent -x 437*da0073e9SAndroid Build Coastguard Worker if options.pytest: 438*da0073e9SAndroid Build Coastguard Worker unittest_args.extend( 439*da0073e9SAndroid Build Coastguard Worker get_pytest_args( 440*da0073e9SAndroid Build Coastguard Worker options, 441*da0073e9SAndroid Build Coastguard Worker is_cpp_test=is_cpp_test, 442*da0073e9SAndroid Build Coastguard Worker is_distributed_test=is_distributed_test, 443*da0073e9SAndroid Build Coastguard Worker ) 444*da0073e9SAndroid Build Coastguard Worker ) 445*da0073e9SAndroid Build Coastguard Worker unittest_args.extend(test_module.get_pytest_args()) 446*da0073e9SAndroid Build Coastguard Worker replacement = {"-f": "-x"} 447*da0073e9SAndroid Build Coastguard Worker unittest_args = [replacement.get(arg, arg) for arg in unittest_args] 448*da0073e9SAndroid Build Coastguard Worker 449*da0073e9SAndroid Build Coastguard Worker if options.showlocals: 450*da0073e9SAndroid Build Coastguard Worker if options.pytest: 451*da0073e9SAndroid Build Coastguard Worker unittest_args.extend(["--showlocals", "--tb=long", "--color=yes"]) 452*da0073e9SAndroid Build Coastguard Worker else: 453*da0073e9SAndroid Build Coastguard Worker unittest_args.append("--locals") 454*da0073e9SAndroid Build Coastguard Worker 455*da0073e9SAndroid Build Coastguard Worker # NB: These features are not available for C++ tests, but 
there is little incentive 456*da0073e9SAndroid Build Coastguard Worker # to implement it because we have never seen a flaky C++ test before. 457*da0073e9SAndroid Build Coastguard Worker if IS_CI and not is_cpp_test: 458*da0073e9SAndroid Build Coastguard Worker ci_args = ["--import-slow-tests", "--import-disabled-tests"] 459*da0073e9SAndroid Build Coastguard Worker if RERUN_DISABLED_TESTS: 460*da0073e9SAndroid Build Coastguard Worker ci_args.append("--rerun-disabled-tests") 461*da0073e9SAndroid Build Coastguard Worker # use the downloaded test cases configuration, not supported in pytest 462*da0073e9SAndroid Build Coastguard Worker unittest_args.extend(ci_args) 463*da0073e9SAndroid Build Coastguard Worker 464*da0073e9SAndroid Build Coastguard Worker if test_file in PYTEST_SKIP_RETRIES: 465*da0073e9SAndroid Build Coastguard Worker if not options.pytest: 466*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 467*da0073e9SAndroid Build Coastguard Worker "A test running without pytest cannot skip retries using " 468*da0073e9SAndroid Build Coastguard Worker "the PYTEST_SKIP_RETRIES set." 469*da0073e9SAndroid Build Coastguard Worker ) 470*da0073e9SAndroid Build Coastguard Worker unittest_args = [arg for arg in unittest_args if "--reruns" not in arg] 471*da0073e9SAndroid Build Coastguard Worker 472*da0073e9SAndroid Build Coastguard Worker # Extra arguments are not supported with pytest 473*da0073e9SAndroid Build Coastguard Worker executable = get_executable_command(options, is_cpp_test=is_cpp_test) 474*da0073e9SAndroid Build Coastguard Worker if not executable: 475*da0073e9SAndroid Build Coastguard Worker # If there is no eligible executable returning here, it means an unsupported 476*da0073e9SAndroid Build Coastguard Worker # case such as coverage for C++ test. 
So just returning ok makes sense 477*da0073e9SAndroid Build Coastguard Worker return 0 478*da0073e9SAndroid Build Coastguard Worker 479*da0073e9SAndroid Build Coastguard Worker if test_file.startswith(CPP_TEST_PREFIX): 480*da0073e9SAndroid Build Coastguard Worker # C++ tests are not the regular test directory 481*da0073e9SAndroid Build Coastguard Worker if CPP_TESTS_DIR: 482*da0073e9SAndroid Build Coastguard Worker cpp_test = os.path.join( 483*da0073e9SAndroid Build Coastguard Worker CPP_TESTS_DIR, 484*da0073e9SAndroid Build Coastguard Worker test_file.replace(f"{CPP_TEST_PREFIX}/", ""), 485*da0073e9SAndroid Build Coastguard Worker ) 486*da0073e9SAndroid Build Coastguard Worker else: 487*da0073e9SAndroid Build Coastguard Worker cpp_test = os.path.join( 488*da0073e9SAndroid Build Coastguard Worker Path(test_directory).parent, 489*da0073e9SAndroid Build Coastguard Worker CPP_TEST_PATH, 490*da0073e9SAndroid Build Coastguard Worker test_file.replace(f"{CPP_TEST_PREFIX}/", ""), 491*da0073e9SAndroid Build Coastguard Worker ) 492*da0073e9SAndroid Build Coastguard Worker 493*da0073e9SAndroid Build Coastguard Worker argv = [ 494*da0073e9SAndroid Build Coastguard Worker cpp_test if sys.platform != "win32" else cpp_test + ".exe" 495*da0073e9SAndroid Build Coastguard Worker ] + unittest_args 496*da0073e9SAndroid Build Coastguard Worker else: 497*da0073e9SAndroid Build Coastguard Worker # Can't call `python -m unittest test_*` here because it doesn't run code 498*da0073e9SAndroid Build Coastguard Worker # in `if __name__ == '__main__': `. So call `python test_*.py` instead. 
499*da0073e9SAndroid Build Coastguard Worker argv = [test_file + ".py"] + unittest_args 500*da0073e9SAndroid Build Coastguard Worker 501*da0073e9SAndroid Build Coastguard Worker os.makedirs(REPO_ROOT / "test" / "test-reports", exist_ok=True) 502*da0073e9SAndroid Build Coastguard Worker if options.pipe_logs: 503*da0073e9SAndroid Build Coastguard Worker log_fd, log_path = tempfile.mkstemp( 504*da0073e9SAndroid Build Coastguard Worker dir=REPO_ROOT / "test" / "test-reports", 505*da0073e9SAndroid Build Coastguard Worker prefix=f"{sanitize_file_name(str(test_module))}_", 506*da0073e9SAndroid Build Coastguard Worker suffix="_toprint.log", 507*da0073e9SAndroid Build Coastguard Worker ) 508*da0073e9SAndroid Build Coastguard Worker os.close(log_fd) 509*da0073e9SAndroid Build Coastguard Worker 510*da0073e9SAndroid Build Coastguard Worker command = (launcher_cmd or []) + executable + argv 511*da0073e9SAndroid Build Coastguard Worker should_retry = ( 512*da0073e9SAndroid Build Coastguard Worker "--subprocess" not in command 513*da0073e9SAndroid Build Coastguard Worker and not RERUN_DISABLED_TESTS 514*da0073e9SAndroid Build Coastguard Worker and not is_cpp_test 515*da0073e9SAndroid Build Coastguard Worker and "-n" not in command 516*da0073e9SAndroid Build Coastguard Worker ) 517*da0073e9SAndroid Build Coastguard Worker timeout = ( 518*da0073e9SAndroid Build Coastguard Worker None 519*da0073e9SAndroid Build Coastguard Worker if not options.enable_timeout 520*da0073e9SAndroid Build Coastguard Worker else THRESHOLD * 6 521*da0073e9SAndroid Build Coastguard Worker if IS_SLOW 522*da0073e9SAndroid Build Coastguard Worker else THRESHOLD * 3 523*da0073e9SAndroid Build Coastguard Worker if should_retry 524*da0073e9SAndroid Build Coastguard Worker and isinstance(test_module, ShardedTest) 525*da0073e9SAndroid Build Coastguard Worker and test_module.time is not None 526*da0073e9SAndroid Build Coastguard Worker else THRESHOLD * 3 527*da0073e9SAndroid Build Coastguard Worker if is_cpp_test 
528*da0073e9SAndroid Build Coastguard Worker else None 529*da0073e9SAndroid Build Coastguard Worker ) 530*da0073e9SAndroid Build Coastguard Worker print_to_stderr(f"Executing {command} ... [{datetime.now()}]") 531*da0073e9SAndroid Build Coastguard Worker 532*da0073e9SAndroid Build Coastguard Worker with ExitStack() as stack: 533*da0073e9SAndroid Build Coastguard Worker output = None 534*da0073e9SAndroid Build Coastguard Worker if options.pipe_logs: 535*da0073e9SAndroid Build Coastguard Worker output = stack.enter_context(open(log_path, "w")) 536*da0073e9SAndroid Build Coastguard Worker 537*da0073e9SAndroid Build Coastguard Worker if should_retry: 538*da0073e9SAndroid Build Coastguard Worker ret_code, was_rerun = run_test_retries( 539*da0073e9SAndroid Build Coastguard Worker command, 540*da0073e9SAndroid Build Coastguard Worker test_directory, 541*da0073e9SAndroid Build Coastguard Worker env, 542*da0073e9SAndroid Build Coastguard Worker timeout, 543*da0073e9SAndroid Build Coastguard Worker stepcurrent_key, 544*da0073e9SAndroid Build Coastguard Worker output, 545*da0073e9SAndroid Build Coastguard Worker options.continue_through_error, 546*da0073e9SAndroid Build Coastguard Worker ) 547*da0073e9SAndroid Build Coastguard Worker else: 548*da0073e9SAndroid Build Coastguard Worker command.extend([f"--sc={stepcurrent_key}", "--print-items"]) 549*da0073e9SAndroid Build Coastguard Worker ret_code, was_rerun = retry_shell( 550*da0073e9SAndroid Build Coastguard Worker command, 551*da0073e9SAndroid Build Coastguard Worker test_directory, 552*da0073e9SAndroid Build Coastguard Worker stdout=output, 553*da0073e9SAndroid Build Coastguard Worker stderr=output, 554*da0073e9SAndroid Build Coastguard Worker env=env, 555*da0073e9SAndroid Build Coastguard Worker timeout=timeout, 556*da0073e9SAndroid Build Coastguard Worker retries=0, 557*da0073e9SAndroid Build Coastguard Worker ) 558*da0073e9SAndroid Build Coastguard Worker 559*da0073e9SAndroid Build Coastguard Worker # Pytest return code 
5 means no test is collected. Exit code 4 is 560*da0073e9SAndroid Build Coastguard Worker # returned when the binary is not a C++ test executable, but 4 can 561*da0073e9SAndroid Build Coastguard Worker # also be returned if the file fails before running any tests. All 562*da0073e9SAndroid Build Coastguard Worker # binary files under build/bin that are not C++ test at the time of 563*da0073e9SAndroid Build Coastguard Worker # this writing have been excluded and new ones should be added to 564*da0073e9SAndroid Build Coastguard Worker # the list of exclusions in tools/testing/discover_tests.py 565*da0073e9SAndroid Build Coastguard Worker ret_code = 0 if ret_code == 5 else ret_code 566*da0073e9SAndroid Build Coastguard Worker 567*da0073e9SAndroid Build Coastguard Worker if options.pipe_logs and print_log: 568*da0073e9SAndroid Build Coastguard Worker handle_log_file( 569*da0073e9SAndroid Build Coastguard Worker test_module, log_path, failed=(ret_code != 0), was_rerun=was_rerun 570*da0073e9SAndroid Build Coastguard Worker ) 571*da0073e9SAndroid Build Coastguard Worker return ret_code 572*da0073e9SAndroid Build Coastguard Worker 573*da0073e9SAndroid Build Coastguard Worker 574*da0073e9SAndroid Build Coastguard Workerdef try_set_cpp_stack_traces(env, command, set=True): 575*da0073e9SAndroid Build Coastguard Worker # Print full c++ stack traces during retries 576*da0073e9SAndroid Build Coastguard Worker env = env or {} 577*da0073e9SAndroid Build Coastguard Worker env["TORCH_SHOW_CPP_STACKTRACES"] = "1" if set else "0" 578*da0073e9SAndroid Build Coastguard Worker return env 579*da0073e9SAndroid Build Coastguard Worker 580*da0073e9SAndroid Build Coastguard Worker 581*da0073e9SAndroid Build Coastguard Workerdef run_test_retries( 582*da0073e9SAndroid Build Coastguard Worker command, 583*da0073e9SAndroid Build Coastguard Worker test_directory, 584*da0073e9SAndroid Build Coastguard Worker env, 585*da0073e9SAndroid Build Coastguard Worker timeout, 586*da0073e9SAndroid Build 
Coastguard Worker stepcurrent_key, 587*da0073e9SAndroid Build Coastguard Worker output, 588*da0073e9SAndroid Build Coastguard Worker continue_through_error, 589*da0073e9SAndroid Build Coastguard Worker): 590*da0073e9SAndroid Build Coastguard Worker # Run the test with -x to stop at first failure. Rerun the test by itself. 591*da0073e9SAndroid Build Coastguard Worker # If it succeeds, move on to the rest of the tests in a new process. If it 592*da0073e9SAndroid Build Coastguard Worker # still fails, see below 593*da0073e9SAndroid Build Coastguard Worker # 594*da0073e9SAndroid Build Coastguard Worker # If continue through error is not set, then we fail fast. 595*da0073e9SAndroid Build Coastguard Worker # 596*da0073e9SAndroid Build Coastguard Worker # If continue through error is set, then we skip that test, and keep going. 597*da0073e9SAndroid Build Coastguard Worker # Basically if the same test fails 3 times in a row, skip the test on the 598*da0073e9SAndroid Build Coastguard Worker # next run, but still fail in the end. I take advantage of the value saved 599*da0073e9SAndroid Build Coastguard Worker # in stepcurrent to keep track of the most recently run test (which is the 600*da0073e9SAndroid Build Coastguard Worker # one that failed if there was a failure). 
601*da0073e9SAndroid Build Coastguard Worker 602*da0073e9SAndroid Build Coastguard Worker def print_to_file(s): 603*da0073e9SAndroid Build Coastguard Worker print(s, file=output, flush=True) 604*da0073e9SAndroid Build Coastguard Worker 605*da0073e9SAndroid Build Coastguard Worker num_failures = defaultdict(int) 606*da0073e9SAndroid Build Coastguard Worker 607*da0073e9SAndroid Build Coastguard Worker print_items = ["--print-items"] 608*da0073e9SAndroid Build Coastguard Worker sc_command = f"--sc={stepcurrent_key}" 609*da0073e9SAndroid Build Coastguard Worker while True: 610*da0073e9SAndroid Build Coastguard Worker ret_code, _ = retry_shell( 611*da0073e9SAndroid Build Coastguard Worker command + [sc_command] + print_items, 612*da0073e9SAndroid Build Coastguard Worker test_directory, 613*da0073e9SAndroid Build Coastguard Worker stdout=output, 614*da0073e9SAndroid Build Coastguard Worker stderr=output, 615*da0073e9SAndroid Build Coastguard Worker env=env, 616*da0073e9SAndroid Build Coastguard Worker timeout=timeout, 617*da0073e9SAndroid Build Coastguard Worker retries=0, # no retries here, we do it ourselves, this is because it handles timeout exceptions well 618*da0073e9SAndroid Build Coastguard Worker ) 619*da0073e9SAndroid Build Coastguard Worker ret_code = 0 if ret_code == 5 else ret_code 620*da0073e9SAndroid Build Coastguard Worker if ret_code == 0 and not sc_command.startswith("--rs="): 621*da0073e9SAndroid Build Coastguard Worker break # Got to the end of the test suite successfully 622*da0073e9SAndroid Build Coastguard Worker signal_name = f" ({SIGNALS_TO_NAMES_DICT[-ret_code]})" if ret_code < 0 else "" 623*da0073e9SAndroid Build Coastguard Worker print_to_file(f"Got exit code {ret_code}{signal_name}") 624*da0073e9SAndroid Build Coastguard Worker 625*da0073e9SAndroid Build Coastguard Worker # Read what just failed/ran 626*da0073e9SAndroid Build Coastguard Worker try: 627*da0073e9SAndroid Build Coastguard Worker with open( 628*da0073e9SAndroid Build Coastguard 
Worker REPO_ROOT / ".pytest_cache/v/cache/stepcurrent" / stepcurrent_key 629*da0073e9SAndroid Build Coastguard Worker ) as f: 630*da0073e9SAndroid Build Coastguard Worker current_failure = f.read() 631*da0073e9SAndroid Build Coastguard Worker except FileNotFoundError: 632*da0073e9SAndroid Build Coastguard Worker print_to_file( 633*da0073e9SAndroid Build Coastguard Worker "No stepcurrent file found. Either pytest didn't get to run (e.g. import error)" 634*da0073e9SAndroid Build Coastguard Worker + " or file got deleted (contact dev infra)" 635*da0073e9SAndroid Build Coastguard Worker ) 636*da0073e9SAndroid Build Coastguard Worker break 637*da0073e9SAndroid Build Coastguard Worker 638*da0073e9SAndroid Build Coastguard Worker env = try_set_cpp_stack_traces(env, command, set=False) 639*da0073e9SAndroid Build Coastguard Worker if ret_code != 0: 640*da0073e9SAndroid Build Coastguard Worker num_failures[current_failure] += 1 641*da0073e9SAndroid Build Coastguard Worker 642*da0073e9SAndroid Build Coastguard Worker if ret_code == 0: 643*da0073e9SAndroid Build Coastguard Worker # Rerunning the previously failing test succeeded, so now we can 644*da0073e9SAndroid Build Coastguard Worker # skip it and move on 645*da0073e9SAndroid Build Coastguard Worker sc_command = f"--scs={stepcurrent_key}" 646*da0073e9SAndroid Build Coastguard Worker print_to_file( 647*da0073e9SAndroid Build Coastguard Worker "Test succeeeded in new process, continuing with the rest of the tests" 648*da0073e9SAndroid Build Coastguard Worker ) 649*da0073e9SAndroid Build Coastguard Worker elif num_failures[current_failure] >= 3: 650*da0073e9SAndroid Build Coastguard Worker if not continue_through_error: 651*da0073e9SAndroid Build Coastguard Worker print_to_file("Stopping at first consistent failure") 652*da0073e9SAndroid Build Coastguard Worker break 653*da0073e9SAndroid Build Coastguard Worker sc_command = f"--scs={stepcurrent_key}" 654*da0073e9SAndroid Build Coastguard Worker print_to_file( 
655*da0073e9SAndroid Build Coastguard Worker "Test failed consistently, " 656*da0073e9SAndroid Build Coastguard Worker "continuing with the rest of the tests due to continue-through-error being set" 657*da0073e9SAndroid Build Coastguard Worker ) 658*da0073e9SAndroid Build Coastguard Worker else: 659*da0073e9SAndroid Build Coastguard Worker env = try_set_cpp_stack_traces(env, command, set=True) 660*da0073e9SAndroid Build Coastguard Worker sc_command = f"--rs={stepcurrent_key}" 661*da0073e9SAndroid Build Coastguard Worker print_to_file("Retrying single test...") 662*da0073e9SAndroid Build Coastguard Worker print_items = [] # do not continue printing them, massive waste of space 663*da0073e9SAndroid Build Coastguard Worker 664*da0073e9SAndroid Build Coastguard Worker consistent_failures = [x[1:-1] for x in num_failures.keys() if num_failures[x] >= 3] 665*da0073e9SAndroid Build Coastguard Worker flaky_failures = [x[1:-1] for x in num_failures.keys() if 0 < num_failures[x] < 3] 666*da0073e9SAndroid Build Coastguard Worker if len(flaky_failures) > 0: 667*da0073e9SAndroid Build Coastguard Worker print_to_file( 668*da0073e9SAndroid Build Coastguard Worker "The following tests failed and then succeeded when run in a new process" 669*da0073e9SAndroid Build Coastguard Worker + f"{flaky_failures}", 670*da0073e9SAndroid Build Coastguard Worker ) 671*da0073e9SAndroid Build Coastguard Worker if len(consistent_failures) > 0: 672*da0073e9SAndroid Build Coastguard Worker print_to_file(f"The following tests failed consistently: {consistent_failures}") 673*da0073e9SAndroid Build Coastguard Worker return 1, True 674*da0073e9SAndroid Build Coastguard Worker return ret_code, any(x > 0 for x in num_failures.values()) 675*da0073e9SAndroid Build Coastguard Worker 676*da0073e9SAndroid Build Coastguard Worker 677*da0073e9SAndroid Build Coastguard Workerdef run_test_with_subprocess(test_module, test_directory, options): 678*da0073e9SAndroid Build Coastguard Worker return run_test( 
679*da0073e9SAndroid Build Coastguard Worker test_module, test_directory, options, extra_unittest_args=["--subprocess"] 680*da0073e9SAndroid Build Coastguard Worker ) 681*da0073e9SAndroid Build Coastguard Worker 682*da0073e9SAndroid Build Coastguard Worker 683*da0073e9SAndroid Build Coastguard Workerdef _test_cpp_extensions_aot(test_directory, options, use_ninja): 684*da0073e9SAndroid Build Coastguard Worker if use_ninja: 685*da0073e9SAndroid Build Coastguard Worker try: 686*da0073e9SAndroid Build Coastguard Worker from torch.utils import cpp_extension 687*da0073e9SAndroid Build Coastguard Worker 688*da0073e9SAndroid Build Coastguard Worker cpp_extension.verify_ninja_availability() 689*da0073e9SAndroid Build Coastguard Worker except RuntimeError: 690*da0073e9SAndroid Build Coastguard Worker print_to_stderr(CPP_EXTENSIONS_ERROR) 691*da0073e9SAndroid Build Coastguard Worker return 1 692*da0073e9SAndroid Build Coastguard Worker 693*da0073e9SAndroid Build Coastguard Worker # Wipe the build folder, if it exists already 694*da0073e9SAndroid Build Coastguard Worker cpp_extensions_test_dir = os.path.join(test_directory, "cpp_extensions") 695*da0073e9SAndroid Build Coastguard Worker cpp_extensions_test_build_dir = os.path.join(cpp_extensions_test_dir, "build") 696*da0073e9SAndroid Build Coastguard Worker if os.path.exists(cpp_extensions_test_build_dir): 697*da0073e9SAndroid Build Coastguard Worker shutil.rmtree(cpp_extensions_test_build_dir) 698*da0073e9SAndroid Build Coastguard Worker 699*da0073e9SAndroid Build Coastguard Worker # Build the test cpp extensions modules 700*da0073e9SAndroid Build Coastguard Worker shell_env = os.environ.copy() 701*da0073e9SAndroid Build Coastguard Worker shell_env["USE_NINJA"] = str(1 if use_ninja else 0) 702*da0073e9SAndroid Build Coastguard Worker cmd = [sys.executable, "setup.py", "install", "--root", "./install"] 703*da0073e9SAndroid Build Coastguard Worker return_code = shell(cmd, cwd=cpp_extensions_test_dir, env=shell_env) 
704*da0073e9SAndroid Build Coastguard Worker if return_code != 0: 705*da0073e9SAndroid Build Coastguard Worker return return_code 706*da0073e9SAndroid Build Coastguard Worker if sys.platform != "win32": 707*da0073e9SAndroid Build Coastguard Worker return_code = shell( 708*da0073e9SAndroid Build Coastguard Worker cmd, 709*da0073e9SAndroid Build Coastguard Worker cwd=os.path.join(cpp_extensions_test_dir, "no_python_abi_suffix_test"), 710*da0073e9SAndroid Build Coastguard Worker env=shell_env, 711*da0073e9SAndroid Build Coastguard Worker ) 712*da0073e9SAndroid Build Coastguard Worker if return_code != 0: 713*da0073e9SAndroid Build Coastguard Worker return return_code 714*da0073e9SAndroid Build Coastguard Worker 715*da0073e9SAndroid Build Coastguard Worker # "install" the test modules and run tests 716*da0073e9SAndroid Build Coastguard Worker python_path = os.environ.get("PYTHONPATH", "") 717*da0073e9SAndroid Build Coastguard Worker from shutil import copyfile 718*da0073e9SAndroid Build Coastguard Worker 719*da0073e9SAndroid Build Coastguard Worker os.environ["USE_NINJA"] = shell_env["USE_NINJA"] 720*da0073e9SAndroid Build Coastguard Worker test_module = "test_cpp_extensions_aot" + ("_ninja" if use_ninja else "_no_ninja") 721*da0073e9SAndroid Build Coastguard Worker copyfile( 722*da0073e9SAndroid Build Coastguard Worker test_directory + "/test_cpp_extensions_aot.py", 723*da0073e9SAndroid Build Coastguard Worker test_directory + "/" + test_module + ".py", 724*da0073e9SAndroid Build Coastguard Worker ) 725*da0073e9SAndroid Build Coastguard Worker try: 726*da0073e9SAndroid Build Coastguard Worker cpp_extensions = os.path.join(test_directory, "cpp_extensions") 727*da0073e9SAndroid Build Coastguard Worker install_directory = "" 728*da0073e9SAndroid Build Coastguard Worker # install directory is the one that is named site-packages 729*da0073e9SAndroid Build Coastguard Worker for root, directories, _ in os.walk(os.path.join(cpp_extensions, "install")): 730*da0073e9SAndroid 
Build Coastguard Worker for directory in directories: 731*da0073e9SAndroid Build Coastguard Worker if "-packages" in directory: 732*da0073e9SAndroid Build Coastguard Worker install_directory = os.path.join(root, directory) 733*da0073e9SAndroid Build Coastguard Worker 734*da0073e9SAndroid Build Coastguard Worker assert install_directory, "install_directory must not be empty" 735*da0073e9SAndroid Build Coastguard Worker os.environ["PYTHONPATH"] = os.pathsep.join([install_directory, python_path]) 736*da0073e9SAndroid Build Coastguard Worker return run_test(ShardedTest(test_module, 1, 1), test_directory, options) 737*da0073e9SAndroid Build Coastguard Worker finally: 738*da0073e9SAndroid Build Coastguard Worker os.environ["PYTHONPATH"] = python_path 739*da0073e9SAndroid Build Coastguard Worker if os.path.exists(test_directory + "/" + test_module + ".py"): 740*da0073e9SAndroid Build Coastguard Worker os.remove(test_directory + "/" + test_module + ".py") 741*da0073e9SAndroid Build Coastguard Worker os.environ.pop("USE_NINJA") 742*da0073e9SAndroid Build Coastguard Worker 743*da0073e9SAndroid Build Coastguard Worker 744*da0073e9SAndroid Build Coastguard Workerdef test_cpp_extensions_aot_ninja(test_module, test_directory, options): 745*da0073e9SAndroid Build Coastguard Worker return _test_cpp_extensions_aot(test_directory, options, use_ninja=True) 746*da0073e9SAndroid Build Coastguard Worker 747*da0073e9SAndroid Build Coastguard Worker 748*da0073e9SAndroid Build Coastguard Workerdef test_cpp_extensions_aot_no_ninja(test_module, test_directory, options): 749*da0073e9SAndroid Build Coastguard Worker return _test_cpp_extensions_aot(test_directory, options, use_ninja=False) 750*da0073e9SAndroid Build Coastguard Worker 751*da0073e9SAndroid Build Coastguard Worker 752*da0073e9SAndroid Build Coastguard Workerdef test_autoload_enable(test_module, test_directory, options): 753*da0073e9SAndroid Build Coastguard Worker return _test_autoload(test_directory, options, enable=True) 
def test_autoload_disable(test_module, test_directory, options):
    """Run test_autoload.py with TORCH_DEVICE_BACKEND_AUTOLOAD disabled."""
    return _test_autoload(test_directory, options, enable=False)


def _test_autoload(test_directory, options, enable=True):
    """Build the cpp_extensions test package and run test_autoload.py against it.

    The package is installed into ``<test_directory>/cpp_extensions/install``.
    test_autoload.py is then run with the installed ``*-packages`` directory
    prepended to PYTHONPATH and TORCH_DEVICE_BACKEND_AUTOLOAD set to
    ``str(int(enable))``. ``options`` is unused.

    Returns the build exit code if it is non-zero, otherwise the exit code of
    test_autoload.py. PYTHONPATH and TORCH_DEVICE_BACKEND_AUTOLOAD are restored
    afterwards, even if locating the install directory fails.
    """
    # Wipe the build folder, if it exists already
    cpp_extensions_test_dir = os.path.join(test_directory, "cpp_extensions")
    cpp_extensions_test_build_dir = os.path.join(cpp_extensions_test_dir, "build")
    if os.path.exists(cpp_extensions_test_build_dir):
        shutil.rmtree(cpp_extensions_test_build_dir)

    # Build the test cpp extensions modules
    cmd = [sys.executable, "setup.py", "install", "--root", "./install"]
    return_code = shell(cmd, cwd=cpp_extensions_test_dir, env=os.environ)
    if return_code != 0:
        return return_code

    # "install" the test modules and run tests
    python_path = os.environ.get("PYTHONPATH", "")
    # Remember any pre-existing value so it is restored rather than deleted.
    previous_autoload = os.environ.get("TORCH_DEVICE_BACKEND_AUTOLOAD")

    try:
        install_directory = ""
        # install directory is the one that is named site-packages
        for root, directories, _ in os.walk(
            os.path.join(cpp_extensions_test_dir, "install")
        ):
            for directory in directories:
                if "-packages" in directory:
                    install_directory = os.path.join(root, directory)

        assert install_directory, "install_directory must not be empty"
        os.environ["PYTHONPATH"] = os.pathsep.join([install_directory, python_path])
        os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = str(int(enable))

        cmd = [sys.executable, "test_autoload.py"]
        return shell(cmd, cwd=test_directory, env=os.environ)
    finally:
        os.environ["PYTHONPATH"] = python_path
        # pop() with a default: the variable may never have been set if the
        # assert above fired, and a KeyError here would mask that error.
        if previous_autoload is None:
            os.environ.pop("TORCH_DEVICE_BACKEND_AUTOLOAD", None)
        else:
            os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = previous_autoload


def test_distributed(test_module, test_directory, options):
    """Run a distributed test module for each configured backend and init method.

    For every backend in DISTRIBUTED_TESTS_CONFIG, and for both file- and
    env-based init, a fresh temp directory is exported as TEMP_DIR. BACKEND
    and the backend's env vars are set in ``os.environ``. The module is then
    run via ``run_test``: under ``mpiexec`` for the mpi backend, otherwise
    with ``--subprocess``. ``os.environ`` is restored and the temp directory
    removed after each run.

    On win32 only the gloo backend with file init is run. The mpi backend is
    skipped unless ``mpiexec`` is on PATH and Python is older than 3.9.

    Returns the first non-zero exit code, or 0 if every run passed.
    """
    # MPI tests are broken with Python-3.9
    mpi_available = subprocess.call(
        "command -v mpiexec", shell=True
    ) == 0 and sys.version_info < (3, 9)
    if options.verbose and not mpi_available:
        print_to_stderr("MPI not available -- MPI backend tests will be skipped")

    config = DISTRIBUTED_TESTS_CONFIG
    for backend, env_vars in config.items():
        if sys.platform == "win32" and backend != "gloo":
            continue
        if backend == "mpi" and not mpi_available:
            continue
        for with_init_file in {True, False}:
            if sys.platform == "win32" and not with_init_file:
                continue
            tmp_dir = tempfile.mkdtemp()
            if options.verbose:
                init_str = "with {} init_method"
                with_init = init_str.format("file" if with_init_file else "env")
                print_to_stderr(
                    f"Running distributed tests for the {backend} backend {with_init}"
                )
            old_environ = dict(os.environ)
            os.environ["TEMP_DIR"] = tmp_dir
            os.environ["BACKEND"] = backend
            os.environ.update(env_vars)
            try:
                os.mkdir(os.path.join(tmp_dir, "barrier"))
                os.mkdir(os.path.join(tmp_dir, "test_dir"))
                if backend == "mpi":
                    # test mpiexec for --noprefix option
                    with open(os.devnull, "w") as devnull:
                        allowrunasroot_opt = (
                            "--allow-run-as-root"
                            if subprocess.call(
                                'mpiexec --allow-run-as-root -n 1 bash -c ""',
                                shell=True,
                                stdout=devnull,
                                stderr=subprocess.STDOUT,
                            )
                            == 0
                            else ""
                        )
                        noprefix_opt = (
                            "--noprefix"
                            if subprocess.call(
                                f'mpiexec {allowrunasroot_opt} -n 1 --noprefix bash -c ""',
                                shell=True,
                                stdout=devnull,
                                stderr=subprocess.STDOUT,
                            )
                            == 0
                            else ""
                        )

                    # Drop options the probes found unsupported: an "" entry
                    # would reach mpiexec as a literal empty argument.
                    mpiexec = ["mpiexec", "-n", "3"] + [
                        opt for opt in (noprefix_opt, allowrunasroot_opt) if opt
                    ]

                    return_code = run_test(
                        test_module, test_directory, options, launcher_cmd=mpiexec
                    )
                else:
                    return_code = run_test(
                        test_module,
                        test_directory,
                        options,
                        extra_unittest_args=["--subprocess"],
                    )
                if return_code != 0:
                    return return_code
            finally:
                shutil.rmtree(tmp_dir)
                os.environ.clear()
                os.environ.update(old_environ)
    return 0


def run_doctests(test_module, test_directory, options):
    """
    Assumes the incoming test module is called doctest, and simply executes the
    xdoctest runner on the torch library itself.
    """
    import xdoctest

    pkgpath = Path(torch.__file__).parent

    exclude_module_list = ["torch._vendor.*"]
    enabled = {
        # TODO: expose these options to the user
        # For now disable all feature-conditional tests
        # 'lapack': 'auto',
        # 'cuda': 'auto',
        # 'cuda1': 'auto',
        # 'qengine': 'auto',
        "lapack": 0,
        "cuda": 0,
        "cuda1": 0,
        "qengine": 0,
        "autograd_profiler": 0,
        "cpp_ext": 0,
        "monitor": 0,
        "onnx": "auto",
    }

    # Resolve "auto" based on a test to determine if the feature is available.
903*da0073e9SAndroid Build Coastguard Worker if enabled["cuda"] == "auto" and torch.cuda.is_available(): 904*da0073e9SAndroid Build Coastguard Worker enabled["cuda"] = True 905*da0073e9SAndroid Build Coastguard Worker 906*da0073e9SAndroid Build Coastguard Worker if ( 907*da0073e9SAndroid Build Coastguard Worker enabled["cuda1"] == "auto" 908*da0073e9SAndroid Build Coastguard Worker and torch.cuda.is_available() 909*da0073e9SAndroid Build Coastguard Worker and torch.cuda.device_count() > 1 910*da0073e9SAndroid Build Coastguard Worker ): 911*da0073e9SAndroid Build Coastguard Worker enabled["cuda1"] = True 912*da0073e9SAndroid Build Coastguard Worker 913*da0073e9SAndroid Build Coastguard Worker if enabled["lapack"] == "auto" and torch._C.has_lapack: 914*da0073e9SAndroid Build Coastguard Worker enabled["lapack"] = True 915*da0073e9SAndroid Build Coastguard Worker 916*da0073e9SAndroid Build Coastguard Worker if enabled["qengine"] == "auto": 917*da0073e9SAndroid Build Coastguard Worker try: 918*da0073e9SAndroid Build Coastguard Worker # Is there a better check if quantization is enabled? 919*da0073e9SAndroid Build Coastguard Worker import torch.ao.nn.quantized as nnq # NOQA: F401 920*da0073e9SAndroid Build Coastguard Worker 921*da0073e9SAndroid Build Coastguard Worker torch.backends.quantized.engine = "qnnpack" 922*da0073e9SAndroid Build Coastguard Worker torch.backends.quantized.engine = "fbgemm" 923*da0073e9SAndroid Build Coastguard Worker except (ImportError, RuntimeError): 924*da0073e9SAndroid Build Coastguard Worker ... 
925*da0073e9SAndroid Build Coastguard Worker else: 926*da0073e9SAndroid Build Coastguard Worker enabled["qengine"] = True 927*da0073e9SAndroid Build Coastguard Worker 928*da0073e9SAndroid Build Coastguard Worker if enabled["onnx"] == "auto": 929*da0073e9SAndroid Build Coastguard Worker try: 930*da0073e9SAndroid Build Coastguard Worker import onnx # NOQA: F401 931*da0073e9SAndroid Build Coastguard Worker import onnxruntime # NOQA: F401 932*da0073e9SAndroid Build Coastguard Worker import onnxscript # NOQA: F401 933*da0073e9SAndroid Build Coastguard Worker except ImportError: 934*da0073e9SAndroid Build Coastguard Worker exclude_module_list.append("torch.onnx.*") 935*da0073e9SAndroid Build Coastguard Worker enabled["onnx"] = False 936*da0073e9SAndroid Build Coastguard Worker else: 937*da0073e9SAndroid Build Coastguard Worker enabled["onnx"] = True 938*da0073e9SAndroid Build Coastguard Worker 939*da0073e9SAndroid Build Coastguard Worker # Set doctest environment variables 940*da0073e9SAndroid Build Coastguard Worker if enabled["cuda"]: 941*da0073e9SAndroid Build Coastguard Worker os.environ["TORCH_DOCTEST_CUDA"] = "1" 942*da0073e9SAndroid Build Coastguard Worker 943*da0073e9SAndroid Build Coastguard Worker if enabled["cuda1"]: 944*da0073e9SAndroid Build Coastguard Worker os.environ["TORCH_DOCTEST_CUDA1"] = "1" 945*da0073e9SAndroid Build Coastguard Worker 946*da0073e9SAndroid Build Coastguard Worker if enabled["lapack"]: 947*da0073e9SAndroid Build Coastguard Worker os.environ["TORCH_DOCTEST_LAPACK"] = "1" 948*da0073e9SAndroid Build Coastguard Worker 949*da0073e9SAndroid Build Coastguard Worker if enabled["qengine"]: 950*da0073e9SAndroid Build Coastguard Worker os.environ["TORCH_DOCTEST_QENGINE"] = "1" 951*da0073e9SAndroid Build Coastguard Worker 952*da0073e9SAndroid Build Coastguard Worker if enabled["autograd_profiler"]: 953*da0073e9SAndroid Build Coastguard Worker os.environ["TORCH_DOCTEST_AUTOGRAD_PROFILER"] = "1" 954*da0073e9SAndroid Build Coastguard Worker 
955*da0073e9SAndroid Build Coastguard Worker if enabled["cpp_ext"]: 956*da0073e9SAndroid Build Coastguard Worker os.environ["TORCH_DOCTEST_CPP_EXT"] = "1" 957*da0073e9SAndroid Build Coastguard Worker 958*da0073e9SAndroid Build Coastguard Worker if enabled["monitor"]: 959*da0073e9SAndroid Build Coastguard Worker os.environ["TORCH_DOCTEST_MONITOR"] = "1" 960*da0073e9SAndroid Build Coastguard Worker 961*da0073e9SAndroid Build Coastguard Worker if enabled["onnx"]: 962*da0073e9SAndroid Build Coastguard Worker os.environ["TORCH_DOCTEST_ONNX"] = "1" 963*da0073e9SAndroid Build Coastguard Worker 964*da0073e9SAndroid Build Coastguard Worker if 0: 965*da0073e9SAndroid Build Coastguard Worker # TODO: could try to enable some of these 966*da0073e9SAndroid Build Coastguard Worker os.environ["TORCH_DOCTEST_QUANTIZED_DYNAMIC"] = "1" 967*da0073e9SAndroid Build Coastguard Worker os.environ["TORCH_DOCTEST_ANOMALY"] = "1" 968*da0073e9SAndroid Build Coastguard Worker os.environ["TORCH_DOCTEST_AUTOGRAD"] = "1" 969*da0073e9SAndroid Build Coastguard Worker os.environ["TORCH_DOCTEST_HUB"] = "1" 970*da0073e9SAndroid Build Coastguard Worker os.environ["TORCH_DOCTEST_DATALOADER"] = "1" 971*da0073e9SAndroid Build Coastguard Worker os.environ["TORCH_DOCTEST_FUTURES"] = "1" 972*da0073e9SAndroid Build Coastguard Worker 973*da0073e9SAndroid Build Coastguard Worker pkgpath = os.path.dirname(torch.__file__) 974*da0073e9SAndroid Build Coastguard Worker 975*da0073e9SAndroid Build Coastguard Worker xdoctest_config = { 976*da0073e9SAndroid Build Coastguard Worker "global_exec": r"\n".join( 977*da0073e9SAndroid Build Coastguard Worker [ 978*da0073e9SAndroid Build Coastguard Worker "from torch import nn", 979*da0073e9SAndroid Build Coastguard Worker "import torch.nn.functional as F", 980*da0073e9SAndroid Build Coastguard Worker "import torch", 981*da0073e9SAndroid Build Coastguard Worker ] 982*da0073e9SAndroid Build Coastguard Worker ), 983*da0073e9SAndroid Build Coastguard Worker "analysis": "static", # 
set to "auto" to test doctests in compiled modules 984*da0073e9SAndroid Build Coastguard Worker "style": "google", 985*da0073e9SAndroid Build Coastguard Worker "options": "+IGNORE_WHITESPACE", 986*da0073e9SAndroid Build Coastguard Worker } 987*da0073e9SAndroid Build Coastguard Worker xdoctest_verbose = max(1, options.verbose) 988*da0073e9SAndroid Build Coastguard Worker run_summary = xdoctest.runner.doctest_module( 989*da0073e9SAndroid Build Coastguard Worker os.fspath(pkgpath), 990*da0073e9SAndroid Build Coastguard Worker config=xdoctest_config, 991*da0073e9SAndroid Build Coastguard Worker verbose=xdoctest_verbose, 992*da0073e9SAndroid Build Coastguard Worker command=options.xdoctest_command, 993*da0073e9SAndroid Build Coastguard Worker argv=[], 994*da0073e9SAndroid Build Coastguard Worker exclude=exclude_module_list, 995*da0073e9SAndroid Build Coastguard Worker ) 996*da0073e9SAndroid Build Coastguard Worker result = 1 if run_summary.get("n_failed", 0) else 0 997*da0073e9SAndroid Build Coastguard Worker return result 998*da0073e9SAndroid Build Coastguard Worker 999*da0073e9SAndroid Build Coastguard Worker 1000*da0073e9SAndroid Build Coastguard Workerdef sanitize_file_name(file: str): 1001*da0073e9SAndroid Build Coastguard Worker return file.replace("\\", ".").replace("/", ".").replace(" ", "_") 1002*da0073e9SAndroid Build Coastguard Worker 1003*da0073e9SAndroid Build Coastguard Worker 1004*da0073e9SAndroid Build Coastguard Workerdef handle_log_file( 1005*da0073e9SAndroid Build Coastguard Worker test: ShardedTest, file_path: str, failed: bool, was_rerun: bool 1006*da0073e9SAndroid Build Coastguard Worker) -> None: 1007*da0073e9SAndroid Build Coastguard Worker test = str(test) 1008*da0073e9SAndroid Build Coastguard Worker with open(file_path, errors="ignore") as f: 1009*da0073e9SAndroid Build Coastguard Worker full_text = f.read() 1010*da0073e9SAndroid Build Coastguard Worker 1011*da0073e9SAndroid Build Coastguard Worker new_file = "test/test-reports/" + 
sanitize_file_name( 1012*da0073e9SAndroid Build Coastguard Worker f"{test}_{os.urandom(8).hex()}_.log" 1013*da0073e9SAndroid Build Coastguard Worker ) 1014*da0073e9SAndroid Build Coastguard Worker os.rename(file_path, REPO_ROOT / new_file) 1015*da0073e9SAndroid Build Coastguard Worker 1016*da0073e9SAndroid Build Coastguard Worker if not failed and not was_rerun and "=== RERUNS ===" not in full_text: 1017*da0073e9SAndroid Build Coastguard Worker # If success + no retries (idk how else to check for test level retries 1018*da0073e9SAndroid Build Coastguard Worker # other than reparse xml), print only what tests ran 1019*da0073e9SAndroid Build Coastguard Worker print_to_stderr( 1020*da0073e9SAndroid Build Coastguard Worker f"\n{test} was successful, full logs can be found in artifacts with path {new_file}" 1021*da0073e9SAndroid Build Coastguard Worker ) 1022*da0073e9SAndroid Build Coastguard Worker for line in full_text.splitlines(): 1023*da0073e9SAndroid Build Coastguard Worker if re.search("Running .* items in this shard:", line): 1024*da0073e9SAndroid Build Coastguard Worker print_to_stderr(line.rstrip()) 1025*da0073e9SAndroid Build Coastguard Worker print_to_stderr("") 1026*da0073e9SAndroid Build Coastguard Worker return 1027*da0073e9SAndroid Build Coastguard Worker 1028*da0073e9SAndroid Build Coastguard Worker # otherwise: print entire file 1029*da0073e9SAndroid Build Coastguard Worker print_to_stderr(f"\nPRINTING LOG FILE of {test} ({new_file})") 1030*da0073e9SAndroid Build Coastguard Worker print_to_stderr(full_text) 1031*da0073e9SAndroid Build Coastguard Worker print_to_stderr(f"FINISHED PRINTING LOG FILE of {test} ({new_file})\n") 1032*da0073e9SAndroid Build Coastguard Worker 1033*da0073e9SAndroid Build Coastguard Worker 1034*da0073e9SAndroid Build Coastguard Workerdef get_pytest_args(options, is_cpp_test=False, is_distributed_test=False): 1035*da0073e9SAndroid Build Coastguard Worker if RERUN_DISABLED_TESTS: 1036*da0073e9SAndroid Build Coastguard Worker # 
Distributed tests are too slow, so running them x50 will cause the jobs to timeout after 1037*da0073e9SAndroid Build Coastguard Worker # 3+ hours. So, let's opt for less number of reruns. We need at least 150 instances of the 1038*da0073e9SAndroid Build Coastguard Worker # test every 2 weeks to satisfy the Rockset query (15 x 14 = 210). The same logic applies 1039*da0073e9SAndroid Build Coastguard Worker # to ASAN, which is also slow 1040*da0073e9SAndroid Build Coastguard Worker count = 15 if is_distributed_test or TEST_WITH_ASAN else 50 1041*da0073e9SAndroid Build Coastguard Worker # When under rerun-disabled-tests mode, run the same tests multiple times to determine their 1042*da0073e9SAndroid Build Coastguard Worker # flakiness status. Default to 50 re-runs 1043*da0073e9SAndroid Build Coastguard Worker rerun_options = ["--flake-finder", f"--flake-runs={count}"] 1044*da0073e9SAndroid Build Coastguard Worker else: 1045*da0073e9SAndroid Build Coastguard Worker # When under the normal mode, retry a failed test 2 more times. 
-x means stop at the first 1046*da0073e9SAndroid Build Coastguard Worker # failure 1047*da0073e9SAndroid Build Coastguard Worker rerun_options = ["-x", "--reruns=2"] 1048*da0073e9SAndroid Build Coastguard Worker 1049*da0073e9SAndroid Build Coastguard Worker pytest_args = [ 1050*da0073e9SAndroid Build Coastguard Worker "-vv", 1051*da0073e9SAndroid Build Coastguard Worker "-rfEX", 1052*da0073e9SAndroid Build Coastguard Worker ] 1053*da0073e9SAndroid Build Coastguard Worker if not is_cpp_test: 1054*da0073e9SAndroid Build Coastguard Worker # C++ tests need to be run with pytest directly, not via python 1055*da0073e9SAndroid Build Coastguard Worker # We have a custom pytest shard that conflicts with the normal plugin 1056*da0073e9SAndroid Build Coastguard Worker pytest_args.extend(["-p", "no:xdist", "--use-pytest"]) 1057*da0073e9SAndroid Build Coastguard Worker else: 1058*da0073e9SAndroid Build Coastguard Worker # Use pytext-dist to run C++ tests in parallel as running them sequentially using run_test 1059*da0073e9SAndroid Build Coastguard Worker # is much slower than running them directly 1060*da0073e9SAndroid Build Coastguard Worker pytest_args.extend(["-n", str(NUM_PROCS)]) 1061*da0073e9SAndroid Build Coastguard Worker 1062*da0073e9SAndroid Build Coastguard Worker if IS_CI: 1063*da0073e9SAndroid Build Coastguard Worker # Add the option to generate XML test report here as C++ tests 1064*da0073e9SAndroid Build Coastguard Worker # won't go into common_utils 1065*da0073e9SAndroid Build Coastguard Worker test_report_path = get_report_path(pytest=True) 1066*da0073e9SAndroid Build Coastguard Worker pytest_args.extend(["--junit-xml-reruns", test_report_path]) 1067*da0073e9SAndroid Build Coastguard Worker 1068*da0073e9SAndroid Build Coastguard Worker if options.pytest_k_expr: 1069*da0073e9SAndroid Build Coastguard Worker pytest_args.extend(["-k", options.pytest_k_expr]) 1070*da0073e9SAndroid Build Coastguard Worker 1071*da0073e9SAndroid Build Coastguard Worker 
pytest_args.extend(rerun_options) 1072*da0073e9SAndroid Build Coastguard Worker return pytest_args 1073*da0073e9SAndroid Build Coastguard Worker 1074*da0073e9SAndroid Build Coastguard Worker 1075*da0073e9SAndroid Build Coastguard Workerdef run_ci_sanity_check(test: ShardedTest, test_directory, options): 1076*da0073e9SAndroid Build Coastguard Worker assert ( 1077*da0073e9SAndroid Build Coastguard Worker test.name == "test_ci_sanity_check_fail" 1078*da0073e9SAndroid Build Coastguard Worker ), f"This handler only works for test_ci_sanity_check_fail, got {test.name}" 1079*da0073e9SAndroid Build Coastguard Worker ret_code = run_test(test, test_directory, options, print_log=False) 1080*da0073e9SAndroid Build Coastguard Worker # This test should fail 1081*da0073e9SAndroid Build Coastguard Worker if ret_code != 1: 1082*da0073e9SAndroid Build Coastguard Worker return 1 1083*da0073e9SAndroid Build Coastguard Worker test_reports_dir = str(REPO_ROOT / "test/test-reports") 1084*da0073e9SAndroid Build Coastguard Worker # Delete the log files and xmls generated by the test 1085*da0073e9SAndroid Build Coastguard Worker for file in glob.glob(f"{test_reports_dir}/{test.name}*.log"): 1086*da0073e9SAndroid Build Coastguard Worker os.remove(file) 1087*da0073e9SAndroid Build Coastguard Worker for dirname in glob.glob(f"{test_reports_dir}/**/{test.name}"): 1088*da0073e9SAndroid Build Coastguard Worker shutil.rmtree(dirname) 1089*da0073e9SAndroid Build Coastguard Worker return 0 1090*da0073e9SAndroid Build Coastguard Worker 1091*da0073e9SAndroid Build Coastguard Worker 1092*da0073e9SAndroid Build Coastguard WorkerCUSTOM_HANDLERS = { 1093*da0073e9SAndroid Build Coastguard Worker "test_cuda_primary_ctx": run_test_with_subprocess, 1094*da0073e9SAndroid Build Coastguard Worker "test_cuda_nvml_based_avail": run_test_with_subprocess, 1095*da0073e9SAndroid Build Coastguard Worker "test_cuda_trace": run_test_with_subprocess, 1096*da0073e9SAndroid Build Coastguard Worker 
"test_cpp_extensions_aot_no_ninja": test_cpp_extensions_aot_no_ninja, 1097*da0073e9SAndroid Build Coastguard Worker "test_cpp_extensions_aot_ninja": test_cpp_extensions_aot_ninja, 1098*da0073e9SAndroid Build Coastguard Worker "distributed/test_distributed_spawn": test_distributed, 1099*da0073e9SAndroid Build Coastguard Worker "distributed/algorithms/quantization/test_quantization": test_distributed, 1100*da0073e9SAndroid Build Coastguard Worker "distributed/test_c10d_nccl": run_test_with_subprocess, 1101*da0073e9SAndroid Build Coastguard Worker "distributed/test_c10d_gloo": run_test_with_subprocess, 1102*da0073e9SAndroid Build Coastguard Worker "distributed/test_c10d_ucc": run_test_with_subprocess, 1103*da0073e9SAndroid Build Coastguard Worker "distributed/test_c10d_common": run_test_with_subprocess, 1104*da0073e9SAndroid Build Coastguard Worker "distributed/test_c10d_spawn_gloo": run_test_with_subprocess, 1105*da0073e9SAndroid Build Coastguard Worker "distributed/test_c10d_spawn_nccl": run_test_with_subprocess, 1106*da0073e9SAndroid Build Coastguard Worker "distributed/test_c10d_spawn_ucc": run_test_with_subprocess, 1107*da0073e9SAndroid Build Coastguard Worker "distributed/test_store": run_test_with_subprocess, 1108*da0073e9SAndroid Build Coastguard Worker "distributed/test_pg_wrapper": run_test_with_subprocess, 1109*da0073e9SAndroid Build Coastguard Worker "distributed/rpc/test_faulty_agent": run_test_with_subprocess, 1110*da0073e9SAndroid Build Coastguard Worker "distributed/rpc/test_tensorpipe_agent": run_test_with_subprocess, 1111*da0073e9SAndroid Build Coastguard Worker "distributed/rpc/test_share_memory": run_test_with_subprocess, 1112*da0073e9SAndroid Build Coastguard Worker "distributed/rpc/cuda/test_tensorpipe_agent": run_test_with_subprocess, 1113*da0073e9SAndroid Build Coastguard Worker "doctests": run_doctests, 1114*da0073e9SAndroid Build Coastguard Worker "test_ci_sanity_check_fail": run_ci_sanity_check, 1115*da0073e9SAndroid Build Coastguard Worker 
"test_autoload_enable": test_autoload_enable, 1116*da0073e9SAndroid Build Coastguard Worker "test_autoload_disable": test_autoload_disable, 1117*da0073e9SAndroid Build Coastguard Worker} 1118*da0073e9SAndroid Build Coastguard Worker 1119*da0073e9SAndroid Build Coastguard Worker 1120*da0073e9SAndroid Build Coastguard WorkerPYTEST_SKIP_RETRIES = {"test_public_bindings"} 1121*da0073e9SAndroid Build Coastguard Worker 1122*da0073e9SAndroid Build Coastguard Worker 1123*da0073e9SAndroid Build Coastguard Workerdef parse_args(): 1124*da0073e9SAndroid Build Coastguard Worker parser = argparse.ArgumentParser( 1125*da0073e9SAndroid Build Coastguard Worker description="Run the PyTorch unit test suite", 1126*da0073e9SAndroid Build Coastguard Worker epilog="where TESTS is any of: {}".format(", ".join(TESTS)), 1127*da0073e9SAndroid Build Coastguard Worker formatter_class=argparse.RawTextHelpFormatter, 1128*da0073e9SAndroid Build Coastguard Worker ) 1129*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 1130*da0073e9SAndroid Build Coastguard Worker "-v", 1131*da0073e9SAndroid Build Coastguard Worker "--verbose", 1132*da0073e9SAndroid Build Coastguard Worker action="count", 1133*da0073e9SAndroid Build Coastguard Worker default=0, 1134*da0073e9SAndroid Build Coastguard Worker help="Print verbose information and test-by-test results", 1135*da0073e9SAndroid Build Coastguard Worker ) 1136*da0073e9SAndroid Build Coastguard Worker if sys.version_info >= (3, 9): 1137*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 1138*da0073e9SAndroid Build Coastguard Worker "--showlocals", 1139*da0073e9SAndroid Build Coastguard Worker action=argparse.BooleanOptionalAction, 1140*da0073e9SAndroid Build Coastguard Worker default=strtobool(os.environ.get("TEST_SHOWLOCALS", "False")), 1141*da0073e9SAndroid Build Coastguard Worker help="Show local variables in tracebacks (default: True)", 1142*da0073e9SAndroid Build Coastguard Worker ) 1143*da0073e9SAndroid Build Coastguard Worker 
else: 1144*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 1145*da0073e9SAndroid Build Coastguard Worker "--showlocals", 1146*da0073e9SAndroid Build Coastguard Worker action="store_true", 1147*da0073e9SAndroid Build Coastguard Worker default=strtobool(os.environ.get("TEST_SHOWLOCALS", "False")), 1148*da0073e9SAndroid Build Coastguard Worker help="Show local variables in tracebacks (default: True)", 1149*da0073e9SAndroid Build Coastguard Worker ) 1150*da0073e9SAndroid Build Coastguard Worker parser.add_argument("--no-showlocals", dest="showlocals", action="store_false") 1151*da0073e9SAndroid Build Coastguard Worker parser.add_argument("--jit", "--jit", action="store_true", help="run all jit tests") 1152*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 1153*da0073e9SAndroid Build Coastguard Worker "--distributed-tests", 1154*da0073e9SAndroid Build Coastguard Worker "--distributed-tests", 1155*da0073e9SAndroid Build Coastguard Worker action="store_true", 1156*da0073e9SAndroid Build Coastguard Worker help="Run all distributed tests", 1157*da0073e9SAndroid Build Coastguard Worker ) 1158*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 1159*da0073e9SAndroid Build Coastguard Worker "--functorch", 1160*da0073e9SAndroid Build Coastguard Worker "--functorch", 1161*da0073e9SAndroid Build Coastguard Worker action="store_true", 1162*da0073e9SAndroid Build Coastguard Worker help=( 1163*da0073e9SAndroid Build Coastguard Worker "If this flag is present, we will only run functorch tests. " 1164*da0073e9SAndroid Build Coastguard Worker "If this flag is not present, we will run all tests " 1165*da0073e9SAndroid Build Coastguard Worker "(including functorch tests)." 
1166*da0073e9SAndroid Build Coastguard Worker ), 1167*da0073e9SAndroid Build Coastguard Worker ) 1168*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 1169*da0073e9SAndroid Build Coastguard Worker "--mps", 1170*da0073e9SAndroid Build Coastguard Worker "--mps", 1171*da0073e9SAndroid Build Coastguard Worker action="store_true", 1172*da0073e9SAndroid Build Coastguard Worker help=("If this flag is present, we will only run test_mps and test_metal"), 1173*da0073e9SAndroid Build Coastguard Worker ) 1174*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 1175*da0073e9SAndroid Build Coastguard Worker "--xpu", 1176*da0073e9SAndroid Build Coastguard Worker "--xpu", 1177*da0073e9SAndroid Build Coastguard Worker action="store_true", 1178*da0073e9SAndroid Build Coastguard Worker help=("If this flag is present, we will run xpu tests except XPU_BLOCK_LIST"), 1179*da0073e9SAndroid Build Coastguard Worker ) 1180*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 1181*da0073e9SAndroid Build Coastguard Worker "--cpp", 1182*da0073e9SAndroid Build Coastguard Worker "--cpp", 1183*da0073e9SAndroid Build Coastguard Worker action="store_true", 1184*da0073e9SAndroid Build Coastguard Worker help=("If this flag is present, we will only run C++ tests"), 1185*da0073e9SAndroid Build Coastguard Worker ) 1186*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 1187*da0073e9SAndroid Build Coastguard Worker "-core", 1188*da0073e9SAndroid Build Coastguard Worker "--core", 1189*da0073e9SAndroid Build Coastguard Worker action="store_true", 1190*da0073e9SAndroid Build Coastguard Worker help="Only run core tests, or tests that validate PyTorch's ops, modules," 1191*da0073e9SAndroid Build Coastguard Worker "and autograd. 
They are defined by CORE_TEST_LIST.", 1192*da0073e9SAndroid Build Coastguard Worker ) 1193*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 1194*da0073e9SAndroid Build Coastguard Worker "--onnx", 1195*da0073e9SAndroid Build Coastguard Worker "--onnx", 1196*da0073e9SAndroid Build Coastguard Worker action="store_true", 1197*da0073e9SAndroid Build Coastguard Worker help=( 1198*da0073e9SAndroid Build Coastguard Worker "Only run ONNX tests, or tests that validate PyTorch's ONNX export. " 1199*da0073e9SAndroid Build Coastguard Worker "If this flag is not present, we will exclude ONNX tests." 1200*da0073e9SAndroid Build Coastguard Worker ), 1201*da0073e9SAndroid Build Coastguard Worker ) 1202*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 1203*da0073e9SAndroid Build Coastguard Worker "-k", 1204*da0073e9SAndroid Build Coastguard Worker "--pytest-k-expr", 1205*da0073e9SAndroid Build Coastguard Worker default="", 1206*da0073e9SAndroid Build Coastguard Worker help="Pass to pytest as its -k expr argument", 1207*da0073e9SAndroid Build Coastguard Worker ) 1208*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 1209*da0073e9SAndroid Build Coastguard Worker "-c", 1210*da0073e9SAndroid Build Coastguard Worker "--coverage", 1211*da0073e9SAndroid Build Coastguard Worker action="store_true", 1212*da0073e9SAndroid Build Coastguard Worker help="enable coverage", 1213*da0073e9SAndroid Build Coastguard Worker default=PYTORCH_COLLECT_COVERAGE, 1214*da0073e9SAndroid Build Coastguard Worker ) 1215*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 1216*da0073e9SAndroid Build Coastguard Worker "-i", 1217*da0073e9SAndroid Build Coastguard Worker "--include", 1218*da0073e9SAndroid Build Coastguard Worker nargs="+", 1219*da0073e9SAndroid Build Coastguard Worker choices=TestChoices(TESTS), 1220*da0073e9SAndroid Build Coastguard Worker default=TESTS, 1221*da0073e9SAndroid Build Coastguard Worker metavar="TESTS", 1222*da0073e9SAndroid Build 
Coastguard Worker help="select a set of tests to include (defaults to ALL tests)." 1223*da0073e9SAndroid Build Coastguard Worker " tests must be a part of the TESTS list defined in run_test.py", 1224*da0073e9SAndroid Build Coastguard Worker ) 1225*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 1226*da0073e9SAndroid Build Coastguard Worker "-x", 1227*da0073e9SAndroid Build Coastguard Worker "--exclude", 1228*da0073e9SAndroid Build Coastguard Worker nargs="+", 1229*da0073e9SAndroid Build Coastguard Worker choices=TESTS, 1230*da0073e9SAndroid Build Coastguard Worker metavar="TESTS", 1231*da0073e9SAndroid Build Coastguard Worker default=[], 1232*da0073e9SAndroid Build Coastguard Worker help="select a set of tests to exclude", 1233*da0073e9SAndroid Build Coastguard Worker ) 1234*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 1235*da0073e9SAndroid Build Coastguard Worker "--ignore-win-blocklist", 1236*da0073e9SAndroid Build Coastguard Worker action="store_true", 1237*da0073e9SAndroid Build Coastguard Worker help="always run blocklisted windows tests", 1238*da0073e9SAndroid Build Coastguard Worker ) 1239*da0073e9SAndroid Build Coastguard Worker # NS: Disable target determination until it can be made more reliable 1240*da0073e9SAndroid Build Coastguard Worker # parser.add_argument( 1241*da0073e9SAndroid Build Coastguard Worker # "--determine-from", 1242*da0073e9SAndroid Build Coastguard Worker # help="File of affected source filenames to determine which tests to run.", 1243*da0073e9SAndroid Build Coastguard Worker # ) 1244*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 1245*da0073e9SAndroid Build Coastguard Worker "--continue-through-error", 1246*da0073e9SAndroid Build Coastguard Worker "--keep-going", 1247*da0073e9SAndroid Build Coastguard Worker action="store_true", 1248*da0073e9SAndroid Build Coastguard Worker help="Runs the full test suite despite one of the tests failing", 1249*da0073e9SAndroid Build Coastguard Worker 
default=strtobool(os.environ.get("CONTINUE_THROUGH_ERROR", "False")), 1250*da0073e9SAndroid Build Coastguard Worker ) 1251*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 1252*da0073e9SAndroid Build Coastguard Worker "--pipe-logs", 1253*da0073e9SAndroid Build Coastguard Worker action="store_true", 1254*da0073e9SAndroid Build Coastguard Worker help="Print logs to output file while running tests. True if in CI and env var is not set", 1255*da0073e9SAndroid Build Coastguard Worker default=IS_CI and not strtobool(os.environ.get("VERBOSE_TEST_LOGS", "False")), 1256*da0073e9SAndroid Build Coastguard Worker ) 1257*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 1258*da0073e9SAndroid Build Coastguard Worker "--enable-timeout", 1259*da0073e9SAndroid Build Coastguard Worker action="store_true", 1260*da0073e9SAndroid Build Coastguard Worker help="Set a timeout based on the test times json file. Only works if there are test times available", 1261*da0073e9SAndroid Build Coastguard Worker default=IS_CI and not strtobool(os.environ.get("NO_TEST_TIMEOUT", "False")), 1262*da0073e9SAndroid Build Coastguard Worker ) 1263*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 1264*da0073e9SAndroid Build Coastguard Worker "--enable-td", 1265*da0073e9SAndroid Build Coastguard Worker action="store_true", 1266*da0073e9SAndroid Build Coastguard Worker help="Enables removing tests based on TD", 1267*da0073e9SAndroid Build Coastguard Worker default=IS_CI 1268*da0073e9SAndroid Build Coastguard Worker and ( 1269*da0073e9SAndroid Build Coastguard Worker TEST_WITH_CROSSREF 1270*da0073e9SAndroid Build Coastguard Worker or TEST_WITH_ASAN 1271*da0073e9SAndroid Build Coastguard Worker or (TEST_CONFIG == "distributed" and TEST_CUDA) 1272*da0073e9SAndroid Build Coastguard Worker or (IS_WINDOWS and not TEST_CUDA) 1273*da0073e9SAndroid Build Coastguard Worker or TEST_CONFIG == "nogpu_AVX512" 1274*da0073e9SAndroid Build Coastguard Worker or TEST_CONFIG == 
"nogpu_NO_AVX2" 1275*da0073e9SAndroid Build Coastguard Worker or TEST_CONFIG == "default" 1276*da0073e9SAndroid Build Coastguard Worker ) 1277*da0073e9SAndroid Build Coastguard Worker and get_pr_number() is not None 1278*da0073e9SAndroid Build Coastguard Worker and not strtobool(os.environ.get("NO_TD", "False")) 1279*da0073e9SAndroid Build Coastguard Worker and not TEST_WITH_ROCM 1280*da0073e9SAndroid Build Coastguard Worker and not IS_MACOS 1281*da0073e9SAndroid Build Coastguard Worker and "xpu" not in BUILD_ENVIRONMENT 1282*da0073e9SAndroid Build Coastguard Worker and "onnx" not in BUILD_ENVIRONMENT 1283*da0073e9SAndroid Build Coastguard Worker and os.environ.get("GITHUB_WORKFLOW", "slow") in ("trunk", "pull"), 1284*da0073e9SAndroid Build Coastguard Worker ) 1285*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 1286*da0073e9SAndroid Build Coastguard Worker "--shard", 1287*da0073e9SAndroid Build Coastguard Worker nargs=2, 1288*da0073e9SAndroid Build Coastguard Worker type=int, 1289*da0073e9SAndroid Build Coastguard Worker help="runs a shard of the tests (taking into account other selections), e.g., " 1290*da0073e9SAndroid Build Coastguard Worker "--shard 2 3 will break up the selected tests into 3 shards and run the tests " 1291*da0073e9SAndroid Build Coastguard Worker "in the 2nd shard (the first number should not exceed the second)", 1292*da0073e9SAndroid Build Coastguard Worker ) 1293*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 1294*da0073e9SAndroid Build Coastguard Worker "--exclude-jit-executor", 1295*da0073e9SAndroid Build Coastguard Worker action="store_true", 1296*da0073e9SAndroid Build Coastguard Worker help="exclude tests that are run for a specific jit config", 1297*da0073e9SAndroid Build Coastguard Worker ) 1298*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 1299*da0073e9SAndroid Build Coastguard Worker "--exclude-torch-export-tests", 1300*da0073e9SAndroid Build Coastguard Worker action="store_true", 
1301*da0073e9SAndroid Build Coastguard Worker help="exclude torch export tests", 1302*da0073e9SAndroid Build Coastguard Worker ) 1303*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 1304*da0073e9SAndroid Build Coastguard Worker "--exclude-distributed-tests", 1305*da0073e9SAndroid Build Coastguard Worker action="store_true", 1306*da0073e9SAndroid Build Coastguard Worker help="exclude distributed tests", 1307*da0073e9SAndroid Build Coastguard Worker ) 1308*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 1309*da0073e9SAndroid Build Coastguard Worker "--exclude-inductor-tests", 1310*da0073e9SAndroid Build Coastguard Worker action="store_true", 1311*da0073e9SAndroid Build Coastguard Worker help="exclude inductor tests", 1312*da0073e9SAndroid Build Coastguard Worker ) 1313*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 1314*da0073e9SAndroid Build Coastguard Worker "--dry-run", 1315*da0073e9SAndroid Build Coastguard Worker action="store_true", 1316*da0073e9SAndroid Build Coastguard Worker help="Only list the test that will run.", 1317*da0073e9SAndroid Build Coastguard Worker ) 1318*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 1319*da0073e9SAndroid Build Coastguard Worker "--xdoctest-command", 1320*da0073e9SAndroid Build Coastguard Worker default="all", 1321*da0073e9SAndroid Build Coastguard Worker help=( 1322*da0073e9SAndroid Build Coastguard Worker "Control the specific doctest action. " 1323*da0073e9SAndroid Build Coastguard Worker "Use 'list' to simply parse doctests and check syntax. 
" 1324*da0073e9SAndroid Build Coastguard Worker "Use 'all' to execute all doctests or specify a specific " 1325*da0073e9SAndroid Build Coastguard Worker "doctest to run" 1326*da0073e9SAndroid Build Coastguard Worker ), 1327*da0073e9SAndroid Build Coastguard Worker ) 1328*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 1329*da0073e9SAndroid Build Coastguard Worker "--no-translation-validation", 1330*da0073e9SAndroid Build Coastguard Worker action="store_false", 1331*da0073e9SAndroid Build Coastguard Worker help="Run tests without translation validation.", 1332*da0073e9SAndroid Build Coastguard Worker ) 1333*da0073e9SAndroid Build Coastguard Worker 1334*da0073e9SAndroid Build Coastguard Worker group = parser.add_mutually_exclusive_group() 1335*da0073e9SAndroid Build Coastguard Worker group.add_argument( 1336*da0073e9SAndroid Build Coastguard Worker "--dynamo", 1337*da0073e9SAndroid Build Coastguard Worker action="store_true", 1338*da0073e9SAndroid Build Coastguard Worker help="Run tests with TorchDynamo+EagerBackend turned on", 1339*da0073e9SAndroid Build Coastguard Worker ) 1340*da0073e9SAndroid Build Coastguard Worker group.add_argument( 1341*da0073e9SAndroid Build Coastguard Worker "--inductor", 1342*da0073e9SAndroid Build Coastguard Worker action="store_true", 1343*da0073e9SAndroid Build Coastguard Worker help="Run tests with TorchInductor turned on", 1344*da0073e9SAndroid Build Coastguard Worker ) 1345*da0073e9SAndroid Build Coastguard Worker 1346*da0073e9SAndroid Build Coastguard Worker args, extra = parser.parse_known_args() 1347*da0073e9SAndroid Build Coastguard Worker if "--" in extra: 1348*da0073e9SAndroid Build Coastguard Worker extra.remove("--") 1349*da0073e9SAndroid Build Coastguard Worker args.additional_args = extra 1350*da0073e9SAndroid Build Coastguard Worker return args 1351*da0073e9SAndroid Build Coastguard Worker 1352*da0073e9SAndroid Build Coastguard Worker 1353*da0073e9SAndroid Build Coastguard Workerdef exclude_tests( 
1354*da0073e9SAndroid Build Coastguard Worker exclude_list, selected_tests, exclude_message=None, exact_match=False 1355*da0073e9SAndroid Build Coastguard Worker): 1356*da0073e9SAndroid Build Coastguard Worker for exclude_test in exclude_list: 1357*da0073e9SAndroid Build Coastguard Worker tests_copy = selected_tests[:] 1358*da0073e9SAndroid Build Coastguard Worker for test in tests_copy: 1359*da0073e9SAndroid Build Coastguard Worker if ( 1360*da0073e9SAndroid Build Coastguard Worker not exact_match and test.startswith(exclude_test) 1361*da0073e9SAndroid Build Coastguard Worker ) or test == exclude_test: 1362*da0073e9SAndroid Build Coastguard Worker if exclude_message is not None: 1363*da0073e9SAndroid Build Coastguard Worker print_to_stderr(f"Excluding {test} {exclude_message}") 1364*da0073e9SAndroid Build Coastguard Worker selected_tests.remove(test) 1365*da0073e9SAndroid Build Coastguard Worker return selected_tests 1366*da0073e9SAndroid Build Coastguard Worker 1367*da0073e9SAndroid Build Coastguard Worker 1368*da0073e9SAndroid Build Coastguard Workerdef must_serial(file: Union[str, ShardedTest]) -> bool: 1369*da0073e9SAndroid Build Coastguard Worker if isinstance(file, ShardedTest): 1370*da0073e9SAndroid Build Coastguard Worker file = file.name 1371*da0073e9SAndroid Build Coastguard Worker return ( 1372*da0073e9SAndroid Build Coastguard Worker os.getenv("PYTORCH_TEST_RUN_EVERYTHING_IN_SERIAL", "0") == "1" 1373*da0073e9SAndroid Build Coastguard Worker or DISTRIBUTED_TEST_PREFIX in os.getenv("TEST_CONFIG", "") 1374*da0073e9SAndroid Build Coastguard Worker or DISTRIBUTED_TEST_PREFIX in file 1375*da0073e9SAndroid Build Coastguard Worker or file in CUSTOM_HANDLERS 1376*da0073e9SAndroid Build Coastguard Worker or file in RUN_PARALLEL_BLOCKLIST 1377*da0073e9SAndroid Build Coastguard Worker or file in CI_SERIAL_LIST 1378*da0073e9SAndroid Build Coastguard Worker or file in JIT_EXECUTOR_TESTS 1379*da0073e9SAndroid Build Coastguard Worker or file in ONNX_SERIAL_LIST 
1380*da0073e9SAndroid Build Coastguard Worker or NUM_PROCS == 1 1381*da0073e9SAndroid Build Coastguard Worker ) 1382*da0073e9SAndroid Build Coastguard Worker 1383*da0073e9SAndroid Build Coastguard Worker 1384*da0073e9SAndroid Build Coastguard Workerdef can_run_in_pytest(test): 1385*da0073e9SAndroid Build Coastguard Worker return os.getenv("PYTORCH_TEST_DO_NOT_USE_PYTEST", "0") == "0" 1386*da0073e9SAndroid Build Coastguard Worker 1387*da0073e9SAndroid Build Coastguard Worker 1388*da0073e9SAndroid Build Coastguard Workerdef get_selected_tests(options) -> List[str]: 1389*da0073e9SAndroid Build Coastguard Worker selected_tests = options.include 1390*da0073e9SAndroid Build Coastguard Worker 1391*da0073e9SAndroid Build Coastguard Worker # filter if there's JIT only and distributed only test options 1392*da0073e9SAndroid Build Coastguard Worker if options.jit: 1393*da0073e9SAndroid Build Coastguard Worker selected_tests = list( 1394*da0073e9SAndroid Build Coastguard Worker filter(lambda test_name: "jit" in test_name, selected_tests) 1395*da0073e9SAndroid Build Coastguard Worker ) 1396*da0073e9SAndroid Build Coastguard Worker 1397*da0073e9SAndroid Build Coastguard Worker if options.distributed_tests: 1398*da0073e9SAndroid Build Coastguard Worker selected_tests = list( 1399*da0073e9SAndroid Build Coastguard Worker filter(lambda test_name: test_name in DISTRIBUTED_TESTS, selected_tests) 1400*da0073e9SAndroid Build Coastguard Worker ) 1401*da0073e9SAndroid Build Coastguard Worker 1402*da0073e9SAndroid Build Coastguard Worker # Filter to only run core tests when --core option is specified 1403*da0073e9SAndroid Build Coastguard Worker if options.core: 1404*da0073e9SAndroid Build Coastguard Worker selected_tests = list( 1405*da0073e9SAndroid Build Coastguard Worker filter(lambda test_name: test_name in CORE_TEST_LIST, selected_tests) 1406*da0073e9SAndroid Build Coastguard Worker ) 1407*da0073e9SAndroid Build Coastguard Worker 1408*da0073e9SAndroid Build Coastguard Worker # 
Filter to only run functorch tests when --functorch option is specified 1409*da0073e9SAndroid Build Coastguard Worker if options.functorch: 1410*da0073e9SAndroid Build Coastguard Worker selected_tests = [tname for tname in selected_tests if tname in FUNCTORCH_TESTS] 1411*da0073e9SAndroid Build Coastguard Worker 1412*da0073e9SAndroid Build Coastguard Worker if options.cpp: 1413*da0073e9SAndroid Build Coastguard Worker selected_tests = [tname for tname in selected_tests if tname in CPP_TESTS] 1414*da0073e9SAndroid Build Coastguard Worker else: 1415*da0073e9SAndroid Build Coastguard Worker # Exclude all C++ tests otherwise as they are still handled differently 1416*da0073e9SAndroid Build Coastguard Worker # than Python test at the moment 1417*da0073e9SAndroid Build Coastguard Worker options.exclude.extend(CPP_TESTS) 1418*da0073e9SAndroid Build Coastguard Worker 1419*da0073e9SAndroid Build Coastguard Worker if options.mps: 1420*da0073e9SAndroid Build Coastguard Worker selected_tests = ["test_mps", "test_metal", "test_modules", "test_nn"] 1421*da0073e9SAndroid Build Coastguard Worker else: 1422*da0073e9SAndroid Build Coastguard Worker # Exclude all mps tests otherwise 1423*da0073e9SAndroid Build Coastguard Worker options.exclude.extend(["test_mps", "test_metal"]) 1424*da0073e9SAndroid Build Coastguard Worker 1425*da0073e9SAndroid Build Coastguard Worker if options.xpu: 1426*da0073e9SAndroid Build Coastguard Worker selected_tests = exclude_tests(XPU_BLOCKLIST, selected_tests, "on XPU") 1427*da0073e9SAndroid Build Coastguard Worker else: 1428*da0073e9SAndroid Build Coastguard Worker # Exclude all xpu specifc tests otherwise 1429*da0073e9SAndroid Build Coastguard Worker options.exclude.extend(XPU_TEST) 1430*da0073e9SAndroid Build Coastguard Worker 1431*da0073e9SAndroid Build Coastguard Worker # Filter to only run onnx tests when --onnx option is specified 1432*da0073e9SAndroid Build Coastguard Worker onnx_tests = [tname for tname in selected_tests if tname in ONNX_TESTS] 
1433*da0073e9SAndroid Build Coastguard Worker if options.onnx: 1434*da0073e9SAndroid Build Coastguard Worker selected_tests = onnx_tests 1435*da0073e9SAndroid Build Coastguard Worker else: 1436*da0073e9SAndroid Build Coastguard Worker # Exclude all onnx tests otherwise 1437*da0073e9SAndroid Build Coastguard Worker options.exclude.extend(onnx_tests) 1438*da0073e9SAndroid Build Coastguard Worker 1439*da0073e9SAndroid Build Coastguard Worker # process exclusion 1440*da0073e9SAndroid Build Coastguard Worker if options.exclude_jit_executor: 1441*da0073e9SAndroid Build Coastguard Worker options.exclude.extend(JIT_EXECUTOR_TESTS) 1442*da0073e9SAndroid Build Coastguard Worker 1443*da0073e9SAndroid Build Coastguard Worker if options.exclude_distributed_tests: 1444*da0073e9SAndroid Build Coastguard Worker options.exclude.extend(DISTRIBUTED_TESTS) 1445*da0073e9SAndroid Build Coastguard Worker 1446*da0073e9SAndroid Build Coastguard Worker if options.exclude_inductor_tests: 1447*da0073e9SAndroid Build Coastguard Worker options.exclude.extend(INDUCTOR_TESTS) 1448*da0073e9SAndroid Build Coastguard Worker 1449*da0073e9SAndroid Build Coastguard Worker if options.exclude_torch_export_tests: 1450*da0073e9SAndroid Build Coastguard Worker options.exclude.extend(TORCH_EXPORT_TESTS) 1451*da0073e9SAndroid Build Coastguard Worker 1452*da0073e9SAndroid Build Coastguard Worker # these tests failing in CUDA 11.6 temporary disabling. 
issue https://github.com/pytorch/pytorch/issues/75375 1453*da0073e9SAndroid Build Coastguard Worker if torch.version.cuda is not None: 1454*da0073e9SAndroid Build Coastguard Worker options.exclude.extend(["distributions/test_constraints"]) 1455*da0073e9SAndroid Build Coastguard Worker 1456*da0073e9SAndroid Build Coastguard Worker # these tests failing in Python 3.12 temporarily disabling 1457*da0073e9SAndroid Build Coastguard Worker if sys.version_info >= (3, 12): 1458*da0073e9SAndroid Build Coastguard Worker options.exclude.extend( 1459*da0073e9SAndroid Build Coastguard Worker [ 1460*da0073e9SAndroid Build Coastguard Worker "functorch/test_dims", 1461*da0073e9SAndroid Build Coastguard Worker "functorch/test_rearrange", 1462*da0073e9SAndroid Build Coastguard Worker "functorch/test_parsing", 1463*da0073e9SAndroid Build Coastguard Worker "functorch/test_memory_efficient_fusion", 1464*da0073e9SAndroid Build Coastguard Worker "torch_np/numpy_tests/core/test_multiarray", 1465*da0073e9SAndroid Build Coastguard Worker ] 1466*da0073e9SAndroid Build Coastguard Worker ) 1467*da0073e9SAndroid Build Coastguard Worker 1468*da0073e9SAndroid Build Coastguard Worker selected_tests = exclude_tests(options.exclude, selected_tests) 1469*da0073e9SAndroid Build Coastguard Worker 1470*da0073e9SAndroid Build Coastguard Worker if sys.platform == "win32" and not options.ignore_win_blocklist: 1471*da0073e9SAndroid Build Coastguard Worker target_arch = os.environ.get("VSCMD_ARG_TGT_ARCH") 1472*da0073e9SAndroid Build Coastguard Worker if target_arch != "x64": 1473*da0073e9SAndroid Build Coastguard Worker WINDOWS_BLOCKLIST.append("cpp_extensions_aot_no_ninja") 1474*da0073e9SAndroid Build Coastguard Worker WINDOWS_BLOCKLIST.append("cpp_extensions_aot_ninja") 1475*da0073e9SAndroid Build Coastguard Worker WINDOWS_BLOCKLIST.append("cpp_extensions_jit") 1476*da0073e9SAndroid Build Coastguard Worker WINDOWS_BLOCKLIST.append("jit") 1477*da0073e9SAndroid Build Coastguard Worker 
WINDOWS_BLOCKLIST.append("jit_fuser") 1478*da0073e9SAndroid Build Coastguard Worker 1479*da0073e9SAndroid Build Coastguard Worker selected_tests = exclude_tests(WINDOWS_BLOCKLIST, selected_tests, "on Windows") 1480*da0073e9SAndroid Build Coastguard Worker 1481*da0073e9SAndroid Build Coastguard Worker elif TEST_WITH_ROCM: 1482*da0073e9SAndroid Build Coastguard Worker selected_tests = exclude_tests(ROCM_BLOCKLIST, selected_tests, "on ROCm") 1483*da0073e9SAndroid Build Coastguard Worker 1484*da0073e9SAndroid Build Coastguard Worker # skip all distributed tests if distributed package is not available. 1485*da0073e9SAndroid Build Coastguard Worker if not dist.is_available(): 1486*da0073e9SAndroid Build Coastguard Worker selected_tests = exclude_tests( 1487*da0073e9SAndroid Build Coastguard Worker DISTRIBUTED_TESTS, 1488*da0073e9SAndroid Build Coastguard Worker selected_tests, 1489*da0073e9SAndroid Build Coastguard Worker "PyTorch is built without distributed support.", 1490*da0073e9SAndroid Build Coastguard Worker ) 1491*da0073e9SAndroid Build Coastguard Worker 1492*da0073e9SAndroid Build Coastguard Worker # skip tests that require LAPACK when it's not available 1493*da0073e9SAndroid Build Coastguard Worker if not torch._C.has_lapack: 1494*da0073e9SAndroid Build Coastguard Worker selected_tests = exclude_tests( 1495*da0073e9SAndroid Build Coastguard Worker TESTS_REQUIRING_LAPACK, 1496*da0073e9SAndroid Build Coastguard Worker selected_tests, 1497*da0073e9SAndroid Build Coastguard Worker "PyTorch is built without LAPACK support.", 1498*da0073e9SAndroid Build Coastguard Worker ) 1499*da0073e9SAndroid Build Coastguard Worker 1500*da0073e9SAndroid Build Coastguard Worker if TEST_WITH_SLOW_GRADCHECK: 1501*da0073e9SAndroid Build Coastguard Worker selected_tests = exclude_tests( 1502*da0073e9SAndroid Build Coastguard Worker TESTS_NOT_USING_GRADCHECK, 1503*da0073e9SAndroid Build Coastguard Worker selected_tests, 1504*da0073e9SAndroid Build Coastguard Worker "Running in slow 
gradcheck mode, skipping tests " 1505*da0073e9SAndroid Build Coastguard Worker "that don't use gradcheck.", 1506*da0073e9SAndroid Build Coastguard Worker exact_match=True, 1507*da0073e9SAndroid Build Coastguard Worker ) 1508*da0073e9SAndroid Build Coastguard Worker 1509*da0073e9SAndroid Build Coastguard Worker selected_tests = [parse_test_module(x) for x in selected_tests] 1510*da0073e9SAndroid Build Coastguard Worker return selected_tests 1511*da0073e9SAndroid Build Coastguard Worker 1512*da0073e9SAndroid Build Coastguard Worker 1513*da0073e9SAndroid Build Coastguard Workerdef load_test_times_from_file(file: str) -> Dict[str, Any]: 1514*da0073e9SAndroid Build Coastguard Worker # Load previous test times to make sharding decisions 1515*da0073e9SAndroid Build Coastguard Worker path = os.path.join(str(REPO_ROOT), file) 1516*da0073e9SAndroid Build Coastguard Worker if not os.path.exists(path): 1517*da0073e9SAndroid Build Coastguard Worker print_to_stderr( 1518*da0073e9SAndroid Build Coastguard Worker f"::warning:: Failed to find test times file `{path}`. Using round robin sharding." 
1519*da0073e9SAndroid Build Coastguard Worker ) 1520*da0073e9SAndroid Build Coastguard Worker return {} 1521*da0073e9SAndroid Build Coastguard Worker 1522*da0073e9SAndroid Build Coastguard Worker with open(path) as f: 1523*da0073e9SAndroid Build Coastguard Worker test_times_file = cast(Dict[str, Any], json.load(f)) 1524*da0073e9SAndroid Build Coastguard Worker build_environment = os.environ.get("BUILD_ENVIRONMENT") 1525*da0073e9SAndroid Build Coastguard Worker test_config = os.environ.get("TEST_CONFIG") 1526*da0073e9SAndroid Build Coastguard Worker if test_config in test_times_file.get(build_environment, {}): 1527*da0073e9SAndroid Build Coastguard Worker print_to_stderr("Found test times from artifacts") 1528*da0073e9SAndroid Build Coastguard Worker return test_times_file[build_environment][test_config] 1529*da0073e9SAndroid Build Coastguard Worker elif test_config in test_times_file["default"]: 1530*da0073e9SAndroid Build Coastguard Worker print_to_stderr( 1531*da0073e9SAndroid Build Coastguard Worker f"::warning:: Gathered no stats from artifacts for {build_environment} build env" 1532*da0073e9SAndroid Build Coastguard Worker f" and {test_config} test config. Using default build env and {test_config} test config instead." 1533*da0073e9SAndroid Build Coastguard Worker ) 1534*da0073e9SAndroid Build Coastguard Worker return test_times_file["default"][test_config] 1535*da0073e9SAndroid Build Coastguard Worker else: 1536*da0073e9SAndroid Build Coastguard Worker print_to_stderr( 1537*da0073e9SAndroid Build Coastguard Worker f"::warning:: Gathered no stats from artifacts for build env {build_environment} build env" 1538*da0073e9SAndroid Build Coastguard Worker f" and {test_config} test config. Using default build env and default test config instead." 
1539*da0073e9SAndroid Build Coastguard Worker ) 1540*da0073e9SAndroid Build Coastguard Worker return test_times_file["default"]["default"] 1541*da0073e9SAndroid Build Coastguard Worker 1542*da0073e9SAndroid Build Coastguard Worker 1543*da0073e9SAndroid Build Coastguard Workerdef load_test_file_times( 1544*da0073e9SAndroid Build Coastguard Worker file: str = ADDITIONAL_CI_FILES_FOLDER / TEST_TIMES_FILE, 1545*da0073e9SAndroid Build Coastguard Worker) -> Dict[str, float]: 1546*da0073e9SAndroid Build Coastguard Worker return cast(Dict[str, float], load_test_times_from_file(file)) 1547*da0073e9SAndroid Build Coastguard Worker 1548*da0073e9SAndroid Build Coastguard Worker 1549*da0073e9SAndroid Build Coastguard Workerdef load_test_class_times( 1550*da0073e9SAndroid Build Coastguard Worker file: str = ADDITIONAL_CI_FILES_FOLDER / TEST_CLASS_TIMES_FILE, 1551*da0073e9SAndroid Build Coastguard Worker) -> Dict[str, Dict[str, float]]: 1552*da0073e9SAndroid Build Coastguard Worker return cast(Dict[str, Dict[str, float]], load_test_times_from_file(file)) 1553*da0073e9SAndroid Build Coastguard Worker 1554*da0073e9SAndroid Build Coastguard Worker 1555*da0073e9SAndroid Build Coastguard Workerdef get_sharding_opts(options) -> Tuple[int, int]: 1556*da0073e9SAndroid Build Coastguard Worker which_shard, num_shards = 1, 1 1557*da0073e9SAndroid Build Coastguard Worker if options.shard: 1558*da0073e9SAndroid Build Coastguard Worker assert len(options.shard) == 2, "Unexpected shard format" 1559*da0073e9SAndroid Build Coastguard Worker assert min(options.shard) > 0, "Shards must be positive numbers" 1560*da0073e9SAndroid Build Coastguard Worker which_shard, num_shards = options.shard 1561*da0073e9SAndroid Build Coastguard Worker assert ( 1562*da0073e9SAndroid Build Coastguard Worker which_shard <= num_shards 1563*da0073e9SAndroid Build Coastguard Worker ), "Selected shard must be less than or equal to total number of shards" 1564*da0073e9SAndroid Build Coastguard Worker 1565*da0073e9SAndroid 
Build Coastguard Worker return (which_shard, num_shards) 1566*da0073e9SAndroid Build Coastguard Worker 1567*da0073e9SAndroid Build Coastguard Worker 1568*da0073e9SAndroid Build Coastguard Workerdef do_sharding( 1569*da0073e9SAndroid Build Coastguard Worker options, 1570*da0073e9SAndroid Build Coastguard Worker selected_tests: Sequence[TestRun], 1571*da0073e9SAndroid Build Coastguard Worker test_file_times: Dict[str, float], 1572*da0073e9SAndroid Build Coastguard Worker test_class_times: Dict[str, Dict[str, float]], 1573*da0073e9SAndroid Build Coastguard Worker sort_by_time: bool = True, 1574*da0073e9SAndroid Build Coastguard Worker) -> Tuple[float, List[ShardedTest]]: 1575*da0073e9SAndroid Build Coastguard Worker which_shard, num_shards = get_sharding_opts(options) 1576*da0073e9SAndroid Build Coastguard Worker 1577*da0073e9SAndroid Build Coastguard Worker # Do sharding 1578*da0073e9SAndroid Build Coastguard Worker shards = calculate_shards( 1579*da0073e9SAndroid Build Coastguard Worker num_shards, 1580*da0073e9SAndroid Build Coastguard Worker selected_tests, 1581*da0073e9SAndroid Build Coastguard Worker test_file_times, 1582*da0073e9SAndroid Build Coastguard Worker test_class_times=test_class_times, 1583*da0073e9SAndroid Build Coastguard Worker must_serial=must_serial, 1584*da0073e9SAndroid Build Coastguard Worker sort_by_time=sort_by_time, 1585*da0073e9SAndroid Build Coastguard Worker ) 1586*da0073e9SAndroid Build Coastguard Worker return shards[which_shard - 1] 1587*da0073e9SAndroid Build Coastguard Worker 1588*da0073e9SAndroid Build Coastguard Worker 1589*da0073e9SAndroid Build Coastguard Workerclass TestFailure(NamedTuple): 1590*da0073e9SAndroid Build Coastguard Worker test: TestRun 1591*da0073e9SAndroid Build Coastguard Worker message: str 1592*da0073e9SAndroid Build Coastguard Worker 1593*da0073e9SAndroid Build Coastguard Worker 1594*da0073e9SAndroid Build Coastguard Workerdef run_test_module( 1595*da0073e9SAndroid Build Coastguard Worker test: ShardedTest, 
test_directory: str, options 1596*da0073e9SAndroid Build Coastguard Worker) -> Optional[TestFailure]: 1597*da0073e9SAndroid Build Coastguard Worker try: 1598*da0073e9SAndroid Build Coastguard Worker maybe_set_hip_visible_devies() 1599*da0073e9SAndroid Build Coastguard Worker 1600*da0073e9SAndroid Build Coastguard Worker test_name = test.name 1601*da0073e9SAndroid Build Coastguard Worker 1602*da0073e9SAndroid Build Coastguard Worker # Printing the date here can help diagnose which tests are slow 1603*da0073e9SAndroid Build Coastguard Worker print_to_stderr(f"Running {str(test)} ... [{datetime.now()}]") 1604*da0073e9SAndroid Build Coastguard Worker handler = CUSTOM_HANDLERS.get(test_name, run_test) 1605*da0073e9SAndroid Build Coastguard Worker return_code = handler(test, test_directory, options) 1606*da0073e9SAndroid Build Coastguard Worker assert isinstance(return_code, int) and not isinstance( 1607*da0073e9SAndroid Build Coastguard Worker return_code, bool 1608*da0073e9SAndroid Build Coastguard Worker ), f"While running {str(test)} got non integer return code {return_code}" 1609*da0073e9SAndroid Build Coastguard Worker if return_code == 0: 1610*da0073e9SAndroid Build Coastguard Worker return None 1611*da0073e9SAndroid Build Coastguard Worker 1612*da0073e9SAndroid Build Coastguard Worker message = f"{str(test)} failed!" 1613*da0073e9SAndroid Build Coastguard Worker if return_code < 0: 1614*da0073e9SAndroid Build Coastguard Worker # subprocess.Popen returns the child process' exit signal as 1615*da0073e9SAndroid Build Coastguard Worker # return code -N, where N is the signal number. 
1616*da0073e9SAndroid Build Coastguard Worker signal_name = SIGNALS_TO_NAMES_DICT[-return_code] 1617*da0073e9SAndroid Build Coastguard Worker message += f" Received signal: {signal_name}" 1618*da0073e9SAndroid Build Coastguard Worker return TestFailure(test.test, message) 1619*da0073e9SAndroid Build Coastguard Worker except Exception as e: 1620*da0073e9SAndroid Build Coastguard Worker return TestFailure(test.test, f"{str(test)} failed! {e}") 1621*da0073e9SAndroid Build Coastguard Worker 1622*da0073e9SAndroid Build Coastguard Worker 1623*da0073e9SAndroid Build Coastguard Workerdef run_tests( 1624*da0073e9SAndroid Build Coastguard Worker selected_tests: List[ShardedTest], 1625*da0073e9SAndroid Build Coastguard Worker test_directory: str, 1626*da0073e9SAndroid Build Coastguard Worker options, 1627*da0073e9SAndroid Build Coastguard Worker failures: List[TestFailure], 1628*da0073e9SAndroid Build Coastguard Worker) -> None: 1629*da0073e9SAndroid Build Coastguard Worker if len(selected_tests) == 0: 1630*da0073e9SAndroid Build Coastguard Worker return 1631*da0073e9SAndroid Build Coastguard Worker 1632*da0073e9SAndroid Build Coastguard Worker # parallel = in parallel with other files 1633*da0073e9SAndroid Build Coastguard Worker # serial = this file on it's own. 
The file might still be run in parallel with itself (ex test_ops) 1634*da0073e9SAndroid Build Coastguard Worker selected_tests_parallel = [x for x in selected_tests if not must_serial(x)] 1635*da0073e9SAndroid Build Coastguard Worker selected_tests_serial = [ 1636*da0073e9SAndroid Build Coastguard Worker x for x in selected_tests if x not in selected_tests_parallel 1637*da0073e9SAndroid Build Coastguard Worker ] 1638*da0073e9SAndroid Build Coastguard Worker 1639*da0073e9SAndroid Build Coastguard Worker # See Note [ROCm parallel CI testing] 1640*da0073e9SAndroid Build Coastguard Worker pool = get_context("spawn").Pool( 1641*da0073e9SAndroid Build Coastguard Worker NUM_PROCS, maxtasksperchild=None if torch.version.hip else 1 1642*da0073e9SAndroid Build Coastguard Worker ) 1643*da0073e9SAndroid Build Coastguard Worker 1644*da0073e9SAndroid Build Coastguard Worker # NB: This is a hack to make conftest.py and files it depends on available 1645*da0073e9SAndroid Build Coastguard Worker # on CPP_TESTS_DIR. 
We should see if the file could be turned into a 1646*da0073e9SAndroid Build Coastguard Worker # full-fledge ptest plugin instead 1647*da0073e9SAndroid Build Coastguard Worker conftest_files = [ 1648*da0073e9SAndroid Build Coastguard Worker "conftest.py", 1649*da0073e9SAndroid Build Coastguard Worker "pytest_shard_custom.py", 1650*da0073e9SAndroid Build Coastguard Worker ] 1651*da0073e9SAndroid Build Coastguard Worker for conftest_file in conftest_files: 1652*da0073e9SAndroid Build Coastguard Worker cpp_file = os.path.join(CPP_TESTS_DIR, conftest_file) 1653*da0073e9SAndroid Build Coastguard Worker if ( 1654*da0073e9SAndroid Build Coastguard Worker options.cpp 1655*da0073e9SAndroid Build Coastguard Worker and os.path.exists(CPP_TESTS_DIR) 1656*da0073e9SAndroid Build Coastguard Worker and os.path.isdir(CPP_TESTS_DIR) 1657*da0073e9SAndroid Build Coastguard Worker and not os.path.exists(cpp_file) 1658*da0073e9SAndroid Build Coastguard Worker ): 1659*da0073e9SAndroid Build Coastguard Worker shutil.copy(os.path.join(test_directory, conftest_file), cpp_file) 1660*da0073e9SAndroid Build Coastguard Worker 1661*da0073e9SAndroid Build Coastguard Worker def handle_error_messages(failure: Optional[TestFailure]): 1662*da0073e9SAndroid Build Coastguard Worker if failure is None: 1663*da0073e9SAndroid Build Coastguard Worker return False 1664*da0073e9SAndroid Build Coastguard Worker failures.append(failure) 1665*da0073e9SAndroid Build Coastguard Worker print_to_stderr(failure.message) 1666*da0073e9SAndroid Build Coastguard Worker return True 1667*da0073e9SAndroid Build Coastguard Worker 1668*da0073e9SAndroid Build Coastguard Worker def parallel_test_completion_callback(failure): 1669*da0073e9SAndroid Build Coastguard Worker test_failed = handle_error_messages(failure) 1670*da0073e9SAndroid Build Coastguard Worker if ( 1671*da0073e9SAndroid Build Coastguard Worker test_failed 1672*da0073e9SAndroid Build Coastguard Worker and not options.continue_through_error 1673*da0073e9SAndroid 
Build Coastguard Worker and not RERUN_DISABLED_TESTS 1674*da0073e9SAndroid Build Coastguard Worker ): 1675*da0073e9SAndroid Build Coastguard Worker pool.terminate() 1676*da0073e9SAndroid Build Coastguard Worker 1677*da0073e9SAndroid Build Coastguard Worker keep_going_message = ( 1678*da0073e9SAndroid Build Coastguard Worker "\n\nTip: You can keep running tests even on failure by passing --keep-going to run_test.py.\n" 1679*da0073e9SAndroid Build Coastguard Worker "If running on CI, add the 'keep-going' label to your PR and rerun your jobs." 1680*da0073e9SAndroid Build Coastguard Worker ) 1681*da0073e9SAndroid Build Coastguard Worker 1682*da0073e9SAndroid Build Coastguard Worker try: 1683*da0073e9SAndroid Build Coastguard Worker for test in selected_tests_serial: 1684*da0073e9SAndroid Build Coastguard Worker options_clone = copy.deepcopy(options) 1685*da0073e9SAndroid Build Coastguard Worker if can_run_in_pytest(test): 1686*da0073e9SAndroid Build Coastguard Worker options_clone.pytest = True 1687*da0073e9SAndroid Build Coastguard Worker failure = run_test_module(test, test_directory, options_clone) 1688*da0073e9SAndroid Build Coastguard Worker test_failed = handle_error_messages(failure) 1689*da0073e9SAndroid Build Coastguard Worker if ( 1690*da0073e9SAndroid Build Coastguard Worker test_failed 1691*da0073e9SAndroid Build Coastguard Worker and not options.continue_through_error 1692*da0073e9SAndroid Build Coastguard Worker and not RERUN_DISABLED_TESTS 1693*da0073e9SAndroid Build Coastguard Worker ): 1694*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(failure.message + keep_going_message) 1695*da0073e9SAndroid Build Coastguard Worker 1696*da0073e9SAndroid Build Coastguard Worker # Run tests marked as serial first 1697*da0073e9SAndroid Build Coastguard Worker for test in selected_tests_parallel: 1698*da0073e9SAndroid Build Coastguard Worker options_clone = copy.deepcopy(options) 1699*da0073e9SAndroid Build Coastguard Worker if can_run_in_pytest(test): 
1700*da0073e9SAndroid Build Coastguard Worker options_clone.pytest = True 1701*da0073e9SAndroid Build Coastguard Worker options_clone.additional_args.extend(["-m", "serial"]) 1702*da0073e9SAndroid Build Coastguard Worker failure = run_test_module(test, test_directory, options_clone) 1703*da0073e9SAndroid Build Coastguard Worker test_failed = handle_error_messages(failure) 1704*da0073e9SAndroid Build Coastguard Worker if ( 1705*da0073e9SAndroid Build Coastguard Worker test_failed 1706*da0073e9SAndroid Build Coastguard Worker and not options.continue_through_error 1707*da0073e9SAndroid Build Coastguard Worker and not RERUN_DISABLED_TESTS 1708*da0073e9SAndroid Build Coastguard Worker ): 1709*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(failure.message + keep_going_message) 1710*da0073e9SAndroid Build Coastguard Worker 1711*da0073e9SAndroid Build Coastguard Worker os.environ["NUM_PARALLEL_PROCS"] = str(NUM_PROCS) 1712*da0073e9SAndroid Build Coastguard Worker for test in selected_tests_parallel: 1713*da0073e9SAndroid Build Coastguard Worker options_clone = copy.deepcopy(options) 1714*da0073e9SAndroid Build Coastguard Worker if can_run_in_pytest(test): 1715*da0073e9SAndroid Build Coastguard Worker options_clone.pytest = True 1716*da0073e9SAndroid Build Coastguard Worker options_clone.additional_args.extend(["-m", "not serial"]) 1717*da0073e9SAndroid Build Coastguard Worker pool.apply_async( 1718*da0073e9SAndroid Build Coastguard Worker run_test_module, 1719*da0073e9SAndroid Build Coastguard Worker args=(test, test_directory, options_clone), 1720*da0073e9SAndroid Build Coastguard Worker callback=parallel_test_completion_callback, 1721*da0073e9SAndroid Build Coastguard Worker ) 1722*da0073e9SAndroid Build Coastguard Worker pool.close() 1723*da0073e9SAndroid Build Coastguard Worker pool.join() 1724*da0073e9SAndroid Build Coastguard Worker del os.environ["NUM_PARALLEL_PROCS"] 1725*da0073e9SAndroid Build Coastguard Worker 1726*da0073e9SAndroid Build Coastguard 
Worker finally: 1727*da0073e9SAndroid Build Coastguard Worker pool.terminate() 1728*da0073e9SAndroid Build Coastguard Worker pool.join() 1729*da0073e9SAndroid Build Coastguard Worker 1730*da0073e9SAndroid Build Coastguard Worker return 1731*da0073e9SAndroid Build Coastguard Worker 1732*da0073e9SAndroid Build Coastguard Worker 1733*da0073e9SAndroid Build Coastguard Workerdef check_pip_packages() -> None: 1734*da0073e9SAndroid Build Coastguard Worker packages = [ 1735*da0073e9SAndroid Build Coastguard Worker "pytest-rerunfailures", 1736*da0073e9SAndroid Build Coastguard Worker "pytest-flakefinder", 1737*da0073e9SAndroid Build Coastguard Worker "pytest-xdist", 1738*da0073e9SAndroid Build Coastguard Worker ] 1739*da0073e9SAndroid Build Coastguard Worker installed_packages = [i.key for i in pkg_resources.working_set] 1740*da0073e9SAndroid Build Coastguard Worker for package in packages: 1741*da0073e9SAndroid Build Coastguard Worker if package not in installed_packages: 1742*da0073e9SAndroid Build Coastguard Worker print_to_stderr( 1743*da0073e9SAndroid Build Coastguard Worker f"Missing pip dependency: {package}, please run `pip install -r .ci/docker/requirements-ci.txt`" 1744*da0073e9SAndroid Build Coastguard Worker ) 1745*da0073e9SAndroid Build Coastguard Worker sys.exit(1) 1746*da0073e9SAndroid Build Coastguard Worker 1747*da0073e9SAndroid Build Coastguard Worker 1748*da0073e9SAndroid Build Coastguard Workerdef main(): 1749*da0073e9SAndroid Build Coastguard Worker check_pip_packages() 1750*da0073e9SAndroid Build Coastguard Worker 1751*da0073e9SAndroid Build Coastguard Worker options = parse_args() 1752*da0073e9SAndroid Build Coastguard Worker 1753*da0073e9SAndroid Build Coastguard Worker # Include sharding info in all metrics 1754*da0073e9SAndroid Build Coastguard Worker which_shard, num_shards = get_sharding_opts(options) 1755*da0073e9SAndroid Build Coastguard Worker add_global_metric("shard", which_shard) 1756*da0073e9SAndroid Build Coastguard Worker 
add_global_metric("num_shards", num_shards) 1757*da0073e9SAndroid Build Coastguard Worker 1758*da0073e9SAndroid Build Coastguard Worker test_directory = str(REPO_ROOT / "test") 1759*da0073e9SAndroid Build Coastguard Worker selected_tests = get_selected_tests(options) 1760*da0073e9SAndroid Build Coastguard Worker 1761*da0073e9SAndroid Build Coastguard Worker test_prioritizations = import_results() 1762*da0073e9SAndroid Build Coastguard Worker test_prioritizations.amend_tests(selected_tests) 1763*da0073e9SAndroid Build Coastguard Worker 1764*da0073e9SAndroid Build Coastguard Worker os.makedirs(REPO_ROOT / "test" / "test-reports", exist_ok=True) 1765*da0073e9SAndroid Build Coastguard Worker 1766*da0073e9SAndroid Build Coastguard Worker if options.coverage and not PYTORCH_COLLECT_COVERAGE: 1767*da0073e9SAndroid Build Coastguard Worker shell(["coverage", "erase"]) 1768*da0073e9SAndroid Build Coastguard Worker 1769*da0073e9SAndroid Build Coastguard Worker if IS_CI: 1770*da0073e9SAndroid Build Coastguard Worker # downloading test cases configuration to local environment 1771*da0073e9SAndroid Build Coastguard Worker get_test_case_configs(dirpath=test_directory) 1772*da0073e9SAndroid Build Coastguard Worker 1773*da0073e9SAndroid Build Coastguard Worker test_file_times_dict = load_test_file_times() 1774*da0073e9SAndroid Build Coastguard Worker test_class_times_dict = load_test_class_times() 1775*da0073e9SAndroid Build Coastguard Worker 1776*da0073e9SAndroid Build Coastguard Worker class TestBatch: 1777*da0073e9SAndroid Build Coastguard Worker """Defines a set of tests with similar priority that should be run together on the current shard""" 1778*da0073e9SAndroid Build Coastguard Worker 1779*da0073e9SAndroid Build Coastguard Worker name: str 1780*da0073e9SAndroid Build Coastguard Worker sharded_tests: List[ShardedTest] 1781*da0073e9SAndroid Build Coastguard Worker failures: List[TestFailure] 1782*da0073e9SAndroid Build Coastguard Worker 1783*da0073e9SAndroid Build Coastguard 
Worker def __init__( 1784*da0073e9SAndroid Build Coastguard Worker self, name: str, raw_tests: Sequence[TestRun], should_sort_shard: bool 1785*da0073e9SAndroid Build Coastguard Worker ): 1786*da0073e9SAndroid Build Coastguard Worker self.name = name 1787*da0073e9SAndroid Build Coastguard Worker self.failures = [] 1788*da0073e9SAndroid Build Coastguard Worker self.time, self.sharded_tests = do_sharding( 1789*da0073e9SAndroid Build Coastguard Worker options, 1790*da0073e9SAndroid Build Coastguard Worker raw_tests, 1791*da0073e9SAndroid Build Coastguard Worker test_file_times_dict, 1792*da0073e9SAndroid Build Coastguard Worker test_class_times_dict, 1793*da0073e9SAndroid Build Coastguard Worker sort_by_time=should_sort_shard, 1794*da0073e9SAndroid Build Coastguard Worker ) 1795*da0073e9SAndroid Build Coastguard Worker 1796*da0073e9SAndroid Build Coastguard Worker def __str__(self): 1797*da0073e9SAndroid Build Coastguard Worker s = f"Name: {self.name} (est. time: {round(self.time / 60, 2)}min)\n" 1798*da0073e9SAndroid Build Coastguard Worker serial = [test for test in self.sharded_tests if must_serial(test)] 1799*da0073e9SAndroid Build Coastguard Worker parallel = [test for test in self.sharded_tests if not must_serial(test)] 1800*da0073e9SAndroid Build Coastguard Worker s += f" Serial tests ({len(serial)}):\n" 1801*da0073e9SAndroid Build Coastguard Worker s += "".join(f" {test}\n" for test in serial) 1802*da0073e9SAndroid Build Coastguard Worker s += f" Parallel tests ({len(parallel)}):\n" 1803*da0073e9SAndroid Build Coastguard Worker s += "".join(f" {test}\n" for test in parallel) 1804*da0073e9SAndroid Build Coastguard Worker return s.strip() 1805*da0073e9SAndroid Build Coastguard Worker 1806*da0073e9SAndroid Build Coastguard Worker percent_to_run = 25 if options.enable_td else 100 1807*da0073e9SAndroid Build Coastguard Worker print_to_stderr( 1808*da0073e9SAndroid Build Coastguard Worker f"Running {percent_to_run}% of tests based on TD" 1809*da0073e9SAndroid Build 
Coastguard Worker if options.enable_td 1810*da0073e9SAndroid Build Coastguard Worker else "Running all tests" 1811*da0073e9SAndroid Build Coastguard Worker ) 1812*da0073e9SAndroid Build Coastguard Worker include, exclude = test_prioritizations.get_top_per_tests(percent_to_run) 1813*da0073e9SAndroid Build Coastguard Worker 1814*da0073e9SAndroid Build Coastguard Worker test_batch = TestBatch("tests to run", include, False) 1815*da0073e9SAndroid Build Coastguard Worker test_batch_exclude = TestBatch("excluded", exclude, True) 1816*da0073e9SAndroid Build Coastguard Worker if IS_CI: 1817*da0073e9SAndroid Build Coastguard Worker gen_ci_artifact([x.to_json() for x in include], [x.to_json() for x in exclude]) 1818*da0073e9SAndroid Build Coastguard Worker 1819*da0073e9SAndroid Build Coastguard Worker print_to_stderr(f"Running parallel tests on {NUM_PROCS} processes") 1820*da0073e9SAndroid Build Coastguard Worker print_to_stderr(test_batch) 1821*da0073e9SAndroid Build Coastguard Worker print_to_stderr(test_batch_exclude) 1822*da0073e9SAndroid Build Coastguard Worker 1823*da0073e9SAndroid Build Coastguard Worker if options.dry_run: 1824*da0073e9SAndroid Build Coastguard Worker return 1825*da0073e9SAndroid Build Coastguard Worker 1826*da0073e9SAndroid Build Coastguard Worker if options.dynamo: 1827*da0073e9SAndroid Build Coastguard Worker os.environ["PYTORCH_TEST_WITH_DYNAMO"] = "1" 1828*da0073e9SAndroid Build Coastguard Worker 1829*da0073e9SAndroid Build Coastguard Worker elif options.inductor: 1830*da0073e9SAndroid Build Coastguard Worker os.environ["PYTORCH_TEST_WITH_INDUCTOR"] = "1" 1831*da0073e9SAndroid Build Coastguard Worker 1832*da0073e9SAndroid Build Coastguard Worker if not options.no_translation_validation: 1833*da0073e9SAndroid Build Coastguard Worker os.environ["PYTORCH_TEST_WITH_TV"] = "1" 1834*da0073e9SAndroid Build Coastguard Worker 1835*da0073e9SAndroid Build Coastguard Worker try: 1836*da0073e9SAndroid Build Coastguard Worker # Actually run the tests 
1837*da0073e9SAndroid Build Coastguard Worker start_time = time.time() 1838*da0073e9SAndroid Build Coastguard Worker run_tests( 1839*da0073e9SAndroid Build Coastguard Worker test_batch.sharded_tests, test_directory, options, test_batch.failures 1840*da0073e9SAndroid Build Coastguard Worker ) 1841*da0073e9SAndroid Build Coastguard Worker elapsed_time = time.time() - start_time 1842*da0073e9SAndroid Build Coastguard Worker print_to_stderr( 1843*da0073e9SAndroid Build Coastguard Worker f"Running test batch '{test_batch.name}' cost {round(elapsed_time, 2)} seconds" 1844*da0073e9SAndroid Build Coastguard Worker ) 1845*da0073e9SAndroid Build Coastguard Worker 1846*da0073e9SAndroid Build Coastguard Worker finally: 1847*da0073e9SAndroid Build Coastguard Worker if options.coverage: 1848*da0073e9SAndroid Build Coastguard Worker from coverage import Coverage 1849*da0073e9SAndroid Build Coastguard Worker 1850*da0073e9SAndroid Build Coastguard Worker with set_cwd(test_directory): 1851*da0073e9SAndroid Build Coastguard Worker cov = Coverage() 1852*da0073e9SAndroid Build Coastguard Worker if PYTORCH_COLLECT_COVERAGE: 1853*da0073e9SAndroid Build Coastguard Worker cov.load() 1854*da0073e9SAndroid Build Coastguard Worker cov.combine(strict=False) 1855*da0073e9SAndroid Build Coastguard Worker cov.save() 1856*da0073e9SAndroid Build Coastguard Worker if not PYTORCH_COLLECT_COVERAGE: 1857*da0073e9SAndroid Build Coastguard Worker cov.html_report() 1858*da0073e9SAndroid Build Coastguard Worker 1859*da0073e9SAndroid Build Coastguard Worker all_failures = test_batch.failures 1860*da0073e9SAndroid Build Coastguard Worker 1861*da0073e9SAndroid Build Coastguard Worker if IS_CI: 1862*da0073e9SAndroid Build Coastguard Worker for test, _ in all_failures: 1863*da0073e9SAndroid Build Coastguard Worker test_stats = test_prioritizations.get_test_stats(test) 1864*da0073e9SAndroid Build Coastguard Worker print_to_stderr("Emiting td_test_failure_stats_v2") 1865*da0073e9SAndroid Build Coastguard Worker 
emit_metric( 1866*da0073e9SAndroid Build Coastguard Worker "td_test_failure_stats_v2", 1867*da0073e9SAndroid Build Coastguard Worker { 1868*da0073e9SAndroid Build Coastguard Worker "selected_tests": selected_tests, 1869*da0073e9SAndroid Build Coastguard Worker "failure": str(test), 1870*da0073e9SAndroid Build Coastguard Worker **test_stats, 1871*da0073e9SAndroid Build Coastguard Worker }, 1872*da0073e9SAndroid Build Coastguard Worker ) 1873*da0073e9SAndroid Build Coastguard Worker gen_additional_test_failures_file( 1874*da0073e9SAndroid Build Coastguard Worker [test.test_file for test, _ in all_failures] 1875*da0073e9SAndroid Build Coastguard Worker ) 1876*da0073e9SAndroid Build Coastguard Worker 1877*da0073e9SAndroid Build Coastguard Worker if len(all_failures): 1878*da0073e9SAndroid Build Coastguard Worker for _, err in all_failures: 1879*da0073e9SAndroid Build Coastguard Worker print_to_stderr(err) 1880*da0073e9SAndroid Build Coastguard Worker 1881*da0073e9SAndroid Build Coastguard Worker # A disabled test is expected to fail, so there is no need to report a failure here 1882*da0073e9SAndroid Build Coastguard Worker if not RERUN_DISABLED_TESTS: 1883*da0073e9SAndroid Build Coastguard Worker sys.exit(1) 1884*da0073e9SAndroid Build Coastguard Worker 1885*da0073e9SAndroid Build Coastguard Worker 1886*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 1887*da0073e9SAndroid Build Coastguard Worker main() 1888