xref: /aosp_15_r20/external/pytorch/tools/testing/target_determination/heuristics/llm.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3import json
4import os
5import re
6from collections import defaultdict
7from pathlib import Path
8from typing import Any
9
10from tools.stats.import_test_stats import ADDITIONAL_CI_FILES_FOLDER
11from tools.testing.target_determination.heuristics.interface import (
12    HeuristicInterface,
13    TestPrioritizations,
14)
15from tools.testing.target_determination.heuristics.utils import normalize_ratings
16from tools.testing.test_run import TestRun
17
18
# This module lives at tools/testing/target_determination/heuristics/llm.py,
# so the repository root is five directory levels above the resolved file.
REPO_ROOT = (
    Path(__file__)
    .resolve()
    .parent  # heuristics/
    .parent  # target_determination/
    .parent  # testing/
    .parent  # tools/
    .parent  # repository root
)
20
21
class LLM(HeuristicInterface):
    """Heuristic that prioritizes test files using precomputed LLM scores.

    The scores are read from a JSON file under
    ``REPO_ROOT / ADDITIONAL_CI_FILES_FOLDER`` (see ``get_mappings``); this
    class only reads that file and does not compute the scores itself.
    """

    # Captures everything before the FIRST literal ".py" in a mapping key,
    # i.e. the test file path without its extension. The dot must be escaped
    # (otherwise any "<char>py" matches) and the group non-greedy (otherwise a
    # later "py" in the key, e.g. in a test name, would be taken instead).
    _TEST_FILE_RE = re.compile(r"(.*?)\.py")

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)

    def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations:
        """Return prioritizations for ``tests`` based on the LLM mapping scores.

        Args:
            tests: Names of the candidate test files. Only scored files that
                appear in this list receive a rating.

        Returns:
            A ``TestPrioritizations`` over ``tests`` built from the matching
            scores after passing them through ``normalize_ratings`` (called
            with ``0.25`` as its second argument).
        """
        mappings = self.get_mappings()
        # Build the membership set once so filtering is linear rather than
        # scanning the whole list for every mapping entry.
        candidates = set(tests)
        valid_scores = {
            TestRun(test): score
            for test, score in mappings.items()
            if test in candidates
        }
        normalized_scores = normalize_ratings(valid_scores, 0.25)
        return TestPrioritizations(tests, normalized_scores)

    def get_mappings(self) -> dict[str, float]:
        """Load the LLM relevance scores and average them per test file.

        Reads ``llm_results/mappings/indexer-files-gitdiff-output.json`` under
        ``REPO_ROOT / ADDITIONAL_CI_FILES_FOLDER``. Each key is reduced to the
        text before its first ``.py`` (for example ``"a/test_x.py::T"`` and
        ``"a/test_x.py"`` both reduce to ``"a/test_x"``); keys containing no
        ``.py`` are skipped. Scores reducing to the same file are averaged.

        Returns:
            Mapping from test file (without the ``.py`` extension) to its mean
            score, or an empty dict if the JSON file does not exist.
        """
        path = (
            REPO_ROOT
            / ADDITIONAL_CI_FILES_FOLDER
            / "llm_results/mappings/indexer-files-gitdiff-output.json"
        )
        if not os.path.exists(path):
            print(f"could not find path {path}")
            return {}
        # Explicit encoding so parsing does not depend on the platform locale.
        with open(path, encoding="utf-8") as f:
            raw_scores = json.load(f)

        # Group scores by test file.
        # NOTE(review): values are assumed to be numeric scores — confirm
        # against the producer of this JSON file.
        scores_by_file: defaultdict[str, list[float]] = defaultdict(list)
        for key, value in raw_scores.items():
            match = self._TEST_FILE_RE.match(key)
            if match:
                scores_by_file[match.group(1)].append(value)

        # Average the scores for each file. Every list is non-empty because
        # entries are only ever created by the append() above.
        return {
            test_file: sum(scores) / len(scores)
            for test_file, scores in scores_by_file.items()
        }
56