xref: /aosp_15_r20/external/pytorch/benchmarks/instruction_counts/applications/ci.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1"""Collect instruction counts for continuous integration."""
2# mypy: ignore-errors
3import argparse
4import hashlib
5import json
6import time
7from typing import Dict, List, Union
8
9from core.expand import materialize
10from definitions.standard import BENCHMARKS
11from execution.runner import Runner
12from execution.work import WorkOrder
13
14
# Each benchmark is scheduled this many times; the per-run results are
# merged into a single entry per benchmark key in `main`.
REPEATS = 5
TIMEOUT = 600  # Seconds (per WorkOrder)
RETRIES = 2  # Retries allowed per WorkOrder before giving up.

# Version tag written into the result JSON. It is overridden with -1 when
# the run is not comparable to a canonical full run (subset mode, or the
# work-order key digest does not match MD5 below).
VERSION = 0
# Expected MD5 digest of the ordered, deduplicated work-order keys; a
# mismatch means the benchmark definitions changed and results should not
# be compared against older data.
MD5 = "4d55e8abf881ad38bb617a96714c1296"
21
22
def main(argv: List[str]) -> None:
    """Collect instruction counts for all benchmarks and emit a JSON report.

    Args:
        argv: Command line arguments (e.g. ``sys.argv[1:]``).
            ``--destination``: optional path; results are written there as JSON.
            ``--subset``: run only the first ten benchmarks (local debugging).

    Side effects: may write a JSON file, prints warnings/summary, and drops
    into the debugger when running in debug mode (no destination or subset).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--destination", type=str, default=None)
    parser.add_argument("--subset", action="store_true")
    args = parser.parse_args(argv)

    t0 = int(time.time())
    version = VERSION
    benchmarks = materialize(BENCHMARKS)

    # Useful for local development, since e2e time for the full suite is O(1 hour)
    in_debug_mode = args.subset or args.destination is None
    if args.subset:
        # A subset run is never comparable to a full run; mark it invalid.
        version = -1
        benchmarks = benchmarks[:10]

    work_orders = tuple(
        WorkOrder(label, autolabels, timer_args, timeout=TIMEOUT, retries=RETRIES)
        for label, autolabels, timer_args in benchmarks * REPEATS
    )

    # Deduplicate keys while preserving order. (Each benchmark appears
    # REPEATS times in `work_orders`.)
    keys = tuple(dict.fromkeys(str(work_order) for work_order in work_orders))
    md5 = hashlib.md5()
    for key in keys:
        md5.update(key.encode("utf-8"))

    # Warn early, since collection takes a long time.
    if md5.hexdigest() != MD5 and not args.subset:
        version = -1
        print(f"WARNING: Expected {MD5}, got {md5.hexdigest()} instead")

    results = Runner(work_orders, cadence=30.0).run()

    # TODO: Annotate with TypedDict when 3.8 is the minimum supported version.
    grouped_results: Dict[str, Dict[str, List[Union[float, int]]]] = {
        key: {"times": [], "counts": []} for key in keys
    }

    # Fold the REPEATS runs of each benchmark into a single entry per key.
    for work_order, r in results.items():
        key = str(work_order)
        grouped_results[key]["times"].extend(r.wall_times)
        grouped_results[key]["counts"].extend(r.instructions)

    final_results = {
        "version": version,
        "md5": md5.hexdigest(),
        "start_time": t0,
        "end_time": int(time.time()),
        "values": grouped_results,
    }

    if args.destination:
        with open(args.destination, "w") as f:
            json.dump(final_results, f)

    if in_debug_mode:
        result_str = json.dumps(final_results)
        print(f"{result_str[:30]} ... {result_str[-30:]}\n")
        # Drop into the debugger so results can be inspected interactively.
        # breakpoint() defaults to pdb.set_trace() (PEP 553).
        breakpoint()
84