xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/scripts/bisect.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3import subprocess
4
5import click
6
7
8def test(cmd, limit):
9    print(f"Testing PYTORCH_JIT_OPT_LIMIT=tensorexpr_fuser={limit} {cmd}")
10    p = subprocess.run(
11        f"PYTORCH_JIT_OPT_LIMIT=tensorexpr_fuser={limit} {cmd}",
12        shell=True,
13        capture_output=True,
14        encoding="utf-8",
15        check=False,
16    )
17    print(p.stdout)
18    f = "INTERNAL ASSERT FAILED"
19    if f in p.stdout or f in p.stderr:
20        print("skip")
21        return -1
22    if p.returncode == 0:
23        print("good")
24        return 1
25    print("bad")
26    return 0
27
28
@click.command()
@click.option("--cmd")
def bisect(cmd):
    """Bisect the tensorexpr_fuser opt limit to find where CMD starts failing.

    Limit 0 is assumed good and limit 10000 is assumed bad; neither bound is
    run. Each round probes limits strictly between the current bounds,
    starting at the midpoint and scanning forward towards the bad bound, then
    backward towards the good bound, until one probe gives a good/bad verdict.
    Limits that report "skip" are remembered and never run again. The search
    ends when no unskipped limit remains between the bounds, and the final
    bounds are printed.
    """
    last_good = 0
    first_bad = 10000
    skips = set()

    while True:
        # Unskipped limits strictly inside (last_good, first_bad). The bounds
        # themselves are excluded: re-running a bound can only repeat its
        # known verdict, which would leave the interval unchanged and loop
        # forever.
        pending = [
            limit
            for limit in range(last_good + 1, first_bad)
            if limit not in skips
        ]
        if not pending:
            break

        test_mid = (last_good + first_bad) // 2
        # Scan forward from mid towards bad first; if everything there skips,
        # scan back from just below mid towards good.
        forward = [limit for limit in pending if limit >= test_mid]
        backward = [limit for limit in reversed(pending) if limit < test_mid]

        for test_limit in forward + backward:
            val = test(cmd, test_limit)
            if val == -1:
                skips.add(test_limit)
                continue
            # A real verdict on an interior limit always narrows the interval,
            # so every round makes progress.
            if val == 0:
                first_bad = test_limit
            else:
                last_good = test_limit
            break

    print(f"last good: {last_good}, first bad: {first_bad}")
69
70
# Run the click command only when executed as a script, not when imported.
if __name__ == "__main__":
    bisect()
73