xref: /aosp_15_r20/external/pytorch/scripts/compile_tests/passrate.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport argparse
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerfrom common import (
4*da0073e9SAndroid Build Coastguard Worker    get_excluded_testcases,
5*da0073e9SAndroid Build Coastguard Worker    get_passed_testcases,
6*da0073e9SAndroid Build Coastguard Worker    get_testcases,
7*da0073e9SAndroid Build Coastguard Worker    key,
8*da0073e9SAndroid Build Coastguard Worker    open_test_results,
9*da0073e9SAndroid Build Coastguard Worker)
10*da0073e9SAndroid Build Coastguard Workerfrom download_reports import download_reports
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Worker"""
14*da0073e9SAndroid Build Coastguard WorkerUsage: passrate.py commit_sha
15*da0073e9SAndroid Build Coastguard Worker
Parses test reports to measure the pass rate. The pass rate is defined as:

A) Take the number of tests that pass under eager mode, excluding
CUDA, OpInfo, and ModuleInfo tests, C++ tests, inductor/export/dynamo
tests, and tests that the Dynamo run reports as excluded
B) Of those tests, count the number of tests that pass under Dynamo
C) Take B/A.
22*da0073e9SAndroid Build Coastguard Worker
23*da0073e9SAndroid Build Coastguard WorkerYou'll need to provide the commit_sha for a commit on the main branch,
24*da0073e9SAndroid Build Coastguard Workerfrom which we will pull CI test results.
25*da0073e9SAndroid Build Coastguard Worker
26*da0073e9SAndroid Build Coastguard WorkerThis script requires the `gh` cli. You'll need to install it and then
27*da0073e9SAndroid Build Coastguard Workerauthenticate with it via `gh auth login` before using this script.
28*da0073e9SAndroid Build Coastguard Workerhttps://docs.github.com/en/github-cli/github-cli/quickstart
29*da0073e9SAndroid Build Coastguard Worker"""
30*da0073e9SAndroid Build Coastguard Worker
31*da0073e9SAndroid Build Coastguard Worker
def testcases_by_time(xmls):
    """Return every testcase found in ``xmls``, slowest first.

    Ordering is by each testcase's ``time`` attribute, parsed as a float.
    The list returned by ``get_testcases`` is sorted in place and returned.
    """

    def duration(testcase):
        return float(testcase.attrib["time"])

    ordered = get_testcases(xmls)
    ordered.sort(key=duration, reverse=True)
    return ordered
36*da0073e9SAndroid Build Coastguard Worker
37*da0073e9SAndroid Build Coastguard Worker
def should_exclude(key):
    """Return True when the test identified by ``key`` is outside the pass-rate policy.

    ``key`` is a ``"<test_file>::..."`` string; only the part before the first
    ``"::"`` (the test file) is inspected.
    """
    test_file, _, _ = key.partition("::")
    # C++ tests show up under the "UNKNOWN" test file; policy: the pass rate
    # also does not include inductor, export, or dynamo tests.
    return test_file == "UNKNOWN" or test_file.startswith(
        ("inductor/", "export/", "dynamo/")
    )
45*da0073e9SAndroid Build Coastguard Worker
46*da0073e9SAndroid Build Coastguard Worker
def _included_keys(testcases):
    """Return the set of keys of ``testcases`` that ``should_exclude`` keeps."""
    return {key_ for key_ in map(key, testcases) if not should_exclude(key_)}


def compute_pass_rate(eager_dir, dynamo_dir):
    """Compute and print the Dynamo pass rate relative to eager mode.

    Denominator: tests that passed in the eager run, minus tests filtered by
    ``should_exclude`` and minus tests the Dynamo run reports as excluded.
    Numerator: the subset of those tests that also passed in the Dynamo run.

    Args:
        eager_dir: Location of the eager-mode test reports, as accepted by
            ``open_test_results``.
        dynamo_dir: Location of the Dynamo test reports, as accepted by
            ``open_test_results``.

    Returns:
        The set of test keys counted in the denominator that did not pass
        under Dynamo.
    """
    print("parsing xmls")
    eager_xmls = open_test_results(eager_dir)
    dynamo_xmls = open_test_results(dynamo_dir)

    print("computing pass rate")
    eager_passed = get_passed_testcases(eager_xmls)
    dynamo_passed = get_passed_testcases(dynamo_xmls)
    dynamo_pass_keys = _included_keys(dynamo_passed)
    # Tests the Dynamo run excludes are removed from the denominator so they
    # are not counted as Dynamo failures.
    dynamo_excluded_keys = {key(t) for t in get_excluded_testcases(dynamo_xmls)}
    eager_pass_keys = _included_keys(eager_passed) - dynamo_excluded_keys

    subset = eager_pass_keys & dynamo_pass_keys
    total_subset = len(subset)
    total_tests = len(eager_pass_keys)
    if total_tests:
        print("pass rate", total_subset / total_tests, total_subset, total_tests)
    else:
        # Guard against ZeroDivisionError when the reports are empty or every
        # eager-passing test was filtered out.
        print("pass rate: no eager-passing tests found; nothing to compare")

    # Useful for debugging: eager-passing tests that do not appear anywhere in
    # the Dynamo reports (as opposed to appearing there and failing).
    dynamo_testcases_by_key = {key(t): t for t in get_testcases(dynamo_xmls)}
    not_there_keys = eager_pass_keys.difference(dynamo_testcases_by_key)  # noqa: F841

    fail_keys = eager_pass_keys - subset
    return fail_keys
80*da0073e9SAndroid Build Coastguard Worker
81*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Command-line entry point: take one commit sha, fetch the matching CI
    # reports for the eager and Dynamo configs, and print the pass rate.
    commit_help = (
        "The commit sha for the latest commit on a PR from which we will "
        "pull CI test results, e.g. 7e5f597aeeba30c390c05f7d316829b3798064a5"
    )
    cli = argparse.ArgumentParser(
        prog="passrate",
        description="Computes the Dynamo unittest pass rate",
    )
    cli.add_argument("commit", help=commit_help)
    commit_sha = cli.parse_args().commit

    # The results are unpacked in the same order as the requested config names.
    dynamo_reports, eager_reports = download_reports(
        commit_sha, ("dynamo311", "eager311")
    )
    compute_pass_rate(eager_reports, dynamo_reports)
96