xref: /aosp_15_r20/external/pytorch/scripts/jit/log_extract.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport argparse
2*da0073e9SAndroid Build Coastguard Workerimport functools
3*da0073e9SAndroid Build Coastguard Workerimport traceback
4*da0073e9SAndroid Build Coastguard Workerfrom typing import Callable, List, Optional, Tuple
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.jit.log_extract import (
7*da0073e9SAndroid Build Coastguard Worker    extract_ir,
8*da0073e9SAndroid Build Coastguard Worker    load_graph_and_inputs,
9*da0073e9SAndroid Build Coastguard Worker    run_baseline_no_fusion,
10*da0073e9SAndroid Build Coastguard Worker    run_nnc,
11*da0073e9SAndroid Build Coastguard Worker    run_nvfuser,
12*da0073e9SAndroid Build Coastguard Worker)
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Worker
"""
Usage:
1. Run your script with JIT logging enabled and redirect its output to a log file:
  PYTORCH_JIT_LOG_LEVEL=">>graph_fuser" python3 my_test.py &> log.txt
2. Run log_extract on that log file, selecting the backends to benchmark:
  log_extract.py log.txt --nvfuser --nnc-dynamic --nnc-static

You can also print the list of extracted IR graphs:
  log_extract.py log.txt --output

Passing --graphs 0 2 restricts the run to graphs 0 and 2.
"""
27*da0073e9SAndroid Build Coastguard Worker
28*da0073e9SAndroid Build Coastguard Worker
29*da0073e9SAndroid Build Coastguard Workerdef test_runners(
30*da0073e9SAndroid Build Coastguard Worker    graphs: List[str],
31*da0073e9SAndroid Build Coastguard Worker    runners: List[Tuple[str, Callable]],
32*da0073e9SAndroid Build Coastguard Worker    graph_set: Optional[List[int]],
33*da0073e9SAndroid Build Coastguard Worker):
34*da0073e9SAndroid Build Coastguard Worker    for i, ir in enumerate(graphs):
35*da0073e9SAndroid Build Coastguard Worker        _, inputs = load_graph_and_inputs(ir)
36*da0073e9SAndroid Build Coastguard Worker        if graph_set and i not in graph_set:
37*da0073e9SAndroid Build Coastguard Worker            continue
38*da0073e9SAndroid Build Coastguard Worker
39*da0073e9SAndroid Build Coastguard Worker        print(f"Running Graph {i}")
40*da0073e9SAndroid Build Coastguard Worker        prev_result = None
41*da0073e9SAndroid Build Coastguard Worker        prev_runner_name = None
42*da0073e9SAndroid Build Coastguard Worker        for runner in runners:
43*da0073e9SAndroid Build Coastguard Worker            runner_name, runner_fn = runner
44*da0073e9SAndroid Build Coastguard Worker            try:
45*da0073e9SAndroid Build Coastguard Worker                result = runner_fn(ir, inputs)
46*da0073e9SAndroid Build Coastguard Worker                if prev_result:
47*da0073e9SAndroid Build Coastguard Worker                    improvement = (prev_result / result - 1) * 100
48*da0073e9SAndroid Build Coastguard Worker                    print(
49*da0073e9SAndroid Build Coastguard Worker                        f"{runner_name} : {result:.6f} ms improvement over {prev_runner_name}: improvement: {improvement:.2f}%"
50*da0073e9SAndroid Build Coastguard Worker                    )
51*da0073e9SAndroid Build Coastguard Worker                else:
52*da0073e9SAndroid Build Coastguard Worker                    print(f"{runner_name} : {result:.6f} ms")
53*da0073e9SAndroid Build Coastguard Worker                prev_result = result
54*da0073e9SAndroid Build Coastguard Worker                prev_runner_name = runner_name
55*da0073e9SAndroid Build Coastguard Worker            except RuntimeError:
56*da0073e9SAndroid Build Coastguard Worker                print(f"  Graph {i} failed for {runner_name} :", traceback.format_exc())
57*da0073e9SAndroid Build Coastguard Worker
58*da0073e9SAndroid Build Coastguard Worker
def _add_toggle(parser, flag, dest, enable_help, disable_help):
    """Register ``--<flag>`` / ``--no-<flag>`` switches storing into ``dest`` (default off)."""
    parser.add_argument(f"--{flag}", dest=dest, action="store_true", help=enable_help)
    parser.add_argument(
        f"--no-{flag}", dest=dest, action="store_false", help=disable_help
    )
    parser.set_defaults(**{dest: False})


def run():
    """Command-line entry point: extract IR from a log file, then benchmark and/or print it.

    Command-line arguments:
        filename: log file passed to ``extract_ir``.
        --baseline / --nnc-dynamic / --nnc-static / --nvfuser: select the
            runners to benchmark (each also has a ``--no-`` form; all are off
            by default). Selected runners execute in that order.
        --output: print the extracted IR strings as a bracketed list of
            triple-quoted strings.
        --graphs: restrict benchmarking and output to the given graph indices.
    """
    parser = argparse.ArgumentParser(
        description="Extracts torchscript IR from log files and, optionally, benchmarks it or outputs the IR"
    )
    parser.add_argument("filename", help="Filename of log file")
    _add_toggle(
        parser, "nvfuser", "nvfuser", "benchmark nvfuser", "DON'T benchmark nvfuser"
    )
    _add_toggle(
        parser,
        "nnc-static",
        "nnc_static",
        "benchmark nnc static",
        "DON'T benchmark nnc static",
    )
    _add_toggle(
        parser,
        "nnc-dynamic",
        "nnc_dynamic",
        "nnc with dynamic shapes",
        "DON'T benchmark nnc with dynamic shapes",
    )
    _add_toggle(
        parser, "baseline", "baseline", "benchmark baseline", "DON'T benchmark baseline"
    )
    _add_toggle(parser, "output", "output", "Output graph IR", "DON'T output graph IR")
    parser.add_argument(
        "--graphs", nargs="+", type=int, help="Run only specified graph indices"
    )

    args = parser.parse_args()
    graphs = extract_ir(args.filename)

    # Treat a missing --graphs the same as "no restriction".
    graph_set = args.graphs or None

    options = []
    if args.baseline:
        options.append(("Baseline no fusion", run_baseline_no_fusion))
    if args.nnc_dynamic:
        options.append(("NNC Dynamic", functools.partial(run_nnc, dynamic=True)))
    if args.nnc_static:
        options.append(("NNC Static", functools.partial(run_nnc, dynamic=False)))
    if args.nvfuser:
        options.append(("NVFuser", run_nvfuser))

    # With no runner requested there is nothing to benchmark; skipping the
    # call keeps an --output-only run from building inputs for every graph
    # and from mixing "Running Graph" lines into the printed IR.
    if options:
        test_runners(graphs, options, graph_set)

    if args.output:
        quoted = [
            '"""' + ir + '"""'
            for i, ir in enumerate(graphs)
            if not graph_set or i in graph_set
        ]
        print("[" + ", ".join(quoted) + "]")
150*da0073e9SAndroid Build Coastguard Worker
151*da0073e9SAndroid Build Coastguard Worker
# Script entry point: parse the command line and run only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    run()
154