1from __future__ import annotations
2
3import json
4import os
5from collections import defaultdict
6from typing import Any, cast, Dict
7from warnings import warn
8
9from tools.stats.import_test_stats import (
10    ADDITIONAL_CI_FILES_FOLDER,
11    TEST_CLASS_RATINGS_FILE,
12)
13from tools.testing.target_determination.heuristics.interface import (
14    HeuristicInterface,
15    TestPrioritizations,
16)
17from tools.testing.target_determination.heuristics.utils import (
18    normalize_ratings,
19    query_changed_files,
20    REPO_ROOT,
21)
22from tools.testing.test_run import TestRun
23
24
class HistoricalClassFailurCorrelation(HeuristicInterface):
    """
    This heuristic prioritizes test classes that have historically tended to fail
    when the files edited by the current PR were modified.
    """

    # NOTE(review): the class name is misspelled ("Failur"); it is left unchanged
    # because other code presumably refers to it by this exact name.

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)

    def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations:
        """
        Rate test classes by their historical failure correlation with the files
        changed in the current PR.

        Args:
            tests: test files that are candidates to run.

        Returns:
            A TestPrioritizations over ``tests`` built from the per-class ratings
            after passing them through ``normalize_ratings(..., 0.25)``.
        """
        test_files = set(tests)
        ratings = _get_ratings_for_tests(test_files)
        test_ratings: dict[TestRun, float] = {}
        for qualified_test_class, score in ratings.items():
            test_run = TestRun(qualified_test_class)
            # Keep only classes whose file was requested; checking against a set
            # avoids a linear scan of ``tests`` for every rated class.
            if test_run.test_file in test_files:
                test_ratings[test_run] = score
        return TestPrioritizations(tests, normalize_ratings(test_ratings, 0.25))
40
41
def _get_historical_test_class_correlations() -> dict[str, dict[str, float]]:
    """
    Load historical test-class failure correlations from the CI artifacts folder.

    The file read is ``REPO_ROOT / ADDITIONAL_CI_FILES_FOLDER /
    TEST_CLASS_RATINGS_FILE``. It is treated as a JSON object that maps an edited
    file to ``{"test_file::test_class": score}``. No validation of that shape is
    performed; the result is only ``cast``.

    Returns:
        The parsed mapping, or an empty dict if the file does not exist. In that
        case a message is printed.
    """
    path = REPO_ROOT / ADDITIONAL_CI_FILES_FOLDER / TEST_CLASS_RATINGS_FILE
    if not os.path.exists(path):
        print(f"could not find path {path}")
        return {}
    # JSON text is UTF-8 (RFC 8259). Be explicit so decoding does not depend on
    # the platform's locale encoding (e.g. cp1252 on Windows).
    with open(path, encoding="utf-8") as f:
        test_class_correlations = cast(Dict[str, Dict[str, float]], json.load(f))
    return test_class_correlations
50
51
def _get_ratings_for_tests(
    tests_to_run: set[str],
) -> dict[str, float]:
    """
    Sum historical failure-correlation scores per test class over the files
    changed in the current PR.

    Args:
        tests_to_run: test files to consider. A class is rated only if the
            file part of its ``"test_file::test_class"`` key is in this set.

    Returns:
        A dict mapping ``"test_file::test_class"`` to the sum of its scores
        across all changed files. The dict is empty in three cases: the changed
        files cannot be queried (a warning is emitted), no correlation data is
        available, or nothing matches.
    """
    # Get the files edited. Any failure here degrades to "no ratings" with a
    # warning instead of propagating.
    try:
        changed_files = query_changed_files()
    except Exception as e:
        warn(f"Can't query changed test files due to {e}")
        return {}

    test_class_correlations = _get_historical_test_class_correlations()
    if not test_class_correlations:
        return {}

    # Find the test-class failures that are correlated with the edited files,
    # keeping only classes that belong to the tests we want to run.
    ratings: dict[str, float] = defaultdict(float)
    for file in changed_files:
        for qualified_test_class, score in test_class_correlations.get(
            file, {}
        ).items():
            # qualified_test_class looks like "test_file::test_class". Split on the
            # first separator only. A key with no "::" or an extra "::" then no
            # longer raises ValueError and aborts the whole heuristic.
            test_file = qualified_test_class.split("::", 1)[0]
            if test_file in tests_to_run:
                ratings[qualified_test_class] += score

    # Hand back a plain dict so lookups by callers cannot silently insert keys.
    return dict(ratings)
79
80
def _rank_correlated_tests(
    tests_to_run: list[str],
) -> list[str]:
    """
    Rank correlated test classes for the current PR's changed files.

    Args:
        tests_to_run: test files to consider. Only classes in these files are
            ranked.

    Returns:
        Qualified test-class names (``"test_file::test_class"``), ordered from
        highest to lowest summed correlation score. Classes with equal scores
        keep the order in which they were first rated.
    """
    # Use a new name instead of rebinding the list-typed parameter to a set,
    # which is a static type error.
    test_files = set(tests_to_run)
    ratings = _get_ratings_for_tests(test_files)
    # sorted() stays stable with reverse=True, so ties keep insertion order.
    # This matches the previous ascending sort on the negated score.
    return sorted(ratings, key=ratings.__getitem__, reverse=True)
90