# xref: /aosp_15_r20/external/pytorch/benchmarks/fastrnns/bench.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport argparse
2*da0073e9SAndroid Build Coastguard Workerimport copy
3*da0073e9SAndroid Build Coastguard Workerimport gc
4*da0073e9SAndroid Build Coastguard Workerimport json
5*da0073e9SAndroid Build Coastguard Workerimport sys
6*da0073e9SAndroid Build Coastguard Workerimport time
7*da0073e9SAndroid Build Coastguard Workerfrom collections import namedtuple
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Workerimport torch
10*da0073e9SAndroid Build Coastguard Workerfrom torch.autograd.profiler import record_function
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Workerfrom .fuser import set_fuser
13*da0073e9SAndroid Build Coastguard Workerfrom .runner import get_nn_runners
14*da0073e9SAndroid Build Coastguard Worker
15*da0073e9SAndroid Build Coastguard Worker
# One benchmark outcome: the runner's name, then the average, standard
# deviation and raw per-iteration timings ("info") for the forward pass,
# followed by the same three values for the backward pass.
BenchResult = namedtuple(
    typename="BenchResult",
    field_names=(
        "name",
        "avg_fwd",
        "std_fwd",
        "info_fwd",
        "avg_bwd",
        "std_bwd",
        "info_bwd",
    ),
)
28*da0073e9SAndroid Build Coastguard Worker
29*da0073e9SAndroid Build Coastguard Worker
def fit_str(string, colwidth=16):
    """Fit ``string`` into a column of ``colwidth`` characters.

    A string shorter than the column is right-aligned by padding it with
    spaces on the left; any other string is cut down to its first
    ``colwidth`` characters.
    """
    missing = colwidth - len(string)
    if missing > 0:
        return " " * missing + string
    return string[:colwidth]
35*da0073e9SAndroid Build Coastguard Worker
36*da0073e9SAndroid Build Coastguard Worker
def to_str(item):
    """Render ``item`` as text for a table cell.

    Floats are formatted with four significant digits (``.4g``); every other
    value goes through ``str``.
    """
    return format(item, ".4g") if isinstance(item, float) else str(item)
41*da0073e9SAndroid Build Coastguard Worker
42*da0073e9SAndroid Build Coastguard Worker
def print_header(colwidth=16, sep=" "):
    """Build the header row that names every ``BenchResult`` field.

    Despite the name, nothing is printed: the formatted row is returned so
    the caller decides where it goes.

    Args:
        colwidth: Width each field name is padded or truncated to.
        sep: String placed between adjacent columns.

    Returns:
        The header row as a single string.
    """
    # Forward colwidth to fit_str; previously it was accepted but ignored, so
    # every column silently used fit_str's own default width.
    return sep.join(fit_str(field, colwidth) for field in BenchResult._fields)
48*da0073e9SAndroid Build Coastguard Worker
49*da0073e9SAndroid Build Coastguard Worker
def pretty_print(benchresult, colwidth=16, sep=" "):
    """Format one benchmark result as a table row.

    Despite the name, nothing is printed: the formatted row is returned.

    Args:
        benchresult: Iterable of cell values (a ``BenchResult`` in this file).
        colwidth: Width each cell is padded or truncated to.
        sep: String placed between adjacent columns.

    Returns:
        The row as a single string.
    """
    # Forward colwidth to fit_str; previously it was accepted but ignored, so
    # every column silently used fit_str's own default width.
    return sep.join(fit_str(to_str(cell), colwidth) for cell in benchresult)
55*da0073e9SAndroid Build Coastguard Worker
56*da0073e9SAndroid Build Coastguard Worker
# Shim for torch.cuda.Event when running on cpu.
class Event:
    """Wall-clock stand-in for the parts of ``torch.cuda.Event`` used here.

    Only ``record`` and ``elapsed_time`` are provided. ``elapsed_time``
    reports milliseconds, matching ``torch.cuda.Event.elapsed_time``, so
    timings gathered through either class share one unit.
    """

    def __init__(self, enable_timing):
        # `enable_timing` exists only so the constructor can be called the
        # same way as torch.cuda.Event; its value is not used.
        pass

    def record(self):
        """Remember the current ``time.perf_counter()`` reading."""
        self.time = time.perf_counter()

    def elapsed_time(self, end_event):
        """Return the milliseconds between this event and ``end_event``.

        Both events must already have been recorded.
        """
        assert isinstance(end_event, Event)
        # perf_counter() counts seconds; scale to milliseconds to match the
        # CUDA event API this class stands in for.
        return (end_event.time - self.time) * 1000.0
68*da0073e9SAndroid Build Coastguard Worker
69*da0073e9SAndroid Build Coastguard Worker
def trainbench(
    name,
    rnn_creator,
    nloops=100,
    warmup=10,
    seqLength=100,
    numLayers=1,
    inputSize=512,
    hiddenSize=512,
    miniBatch=64,
    device="cuda",
    seed=None,
):
    """Time the forward and backward passes of one model.

    The model definition is built once by ``rnn_creator``, run ``warmup``
    times with the timings discarded, then run ``nloops`` more times with the
    timings kept.

    Args:
        name: Label stored in the returned ``BenchResult``.
        rnn_creator: Callable invoked with the keyword arguments ``seqLength``,
            ``numLayers``, ``inputSize``, ``hiddenSize``, ``miniBatch``,
            ``device`` and ``seed``. The object it returns must expose
            ``forward``, ``inputs``, ``backward_setup``, ``backward`` and
            ``params``.
        nloops: Number of timed iterations; must be at least 1.
        warmup: Number of untimed iterations run first.
        seqLength, numLayers, inputSize, hiddenSize, miniBatch, seed:
            Forwarded unchanged to ``rnn_creator``.
        device: Forwarded to ``rnn_creator``. When it equals ``"cuda"``,
            timing uses ``torch.cuda.Event`` and each iteration ends with
            ``torch.cuda.synchronize()``; otherwise the ``Event`` shim is used.

    Returns:
        A ``BenchResult`` holding the mean, standard deviation and the full
        tensor of per-iteration times for the forward and the backward pass.

    Raises:
        ValueError: If ``nloops`` is smaller than 1.
    """
    # Fail before building the model: with no timed iterations there is
    # nothing to aggregate (this used to surface as an unpacking error only
    # after the model had been created and warmed up).
    if nloops < 1:
        raise ValueError(f"nloops must be at least 1, got {nloops}")

    # The timer type depends only on `device`, so choose it once here rather
    # than on every iteration.
    timer_class = torch.cuda.Event if device == "cuda" else Event

    def train_batch(modeldef):
        """Run one forward/backward iteration and return ``(fwd, bwd)`` times."""
        fwd_start_event = timer_class(enable_timing=True)
        fwd_end_event = timer_class(enable_timing=True)
        bwd_start_event = timer_class(enable_timing=True)
        bwd_end_event = timer_class(enable_timing=True)

        # Collect garbage right before the timed region (presumably to keep
        # collector pauses out of the measurement).
        gc.collect()

        fwd_start_event.record()
        with record_function("## forward ##"):
            forward_output = modeldef.forward(*modeldef.inputs)
        fwd_end_event.record()

        # XXX: Use if need to print something
        # print(modeldef.forward.graph_for(*modeldef.inputs))

        if modeldef.backward_setup is not None:
            backward_input = modeldef.backward_setup(forward_output)
        else:
            backward_input = forward_output

        gc.collect()

        # With no backward function the two events are recorded back to back.
        bwd_start_event.record()
        if modeldef.backward is not None:
            modeldef.backward(*backward_input)
        bwd_end_event.record()

        if modeldef.backward is not None:
            # Clear the gradients so they do not accumulate across iterations.
            with torch.no_grad():
                for param in modeldef.params:
                    assert param.grad is not None
                    param.grad.zero_()

        if device == "cuda":
            # Wait for the recorded events to complete before reading them.
            torch.cuda.synchronize()

        fwd_time = fwd_start_event.elapsed_time(fwd_end_event)
        bwd_time = bwd_start_event.elapsed_time(bwd_end_event)
        return fwd_time, bwd_time

    creator_args = {
        "seqLength": seqLength,
        "numLayers": numLayers,
        "inputSize": inputSize,
        "hiddenSize": hiddenSize,
        "miniBatch": miniBatch,
        "device": device,
        "seed": seed,
    }

    modeldef = rnn_creator(**creator_args)

    # Warmup iterations: run for their side effects only, timings discarded.
    for _ in range(warmup):
        train_batch(modeldef)

    results = [train_batch(modeldef) for _ in range(nloops)]
    fwd_times, bwd_times = zip(*results)

    fwd_times = torch.tensor(fwd_times)
    bwd_times = torch.tensor(bwd_times)
    return BenchResult(
        name=name,
        avg_fwd=fwd_times.mean().item(),
        std_fwd=fwd_times.std().item(),
        info_fwd=fwd_times,
        avg_bwd=bwd_times.mean().item(),
        std_bwd=bwd_times.std().item(),
        info_bwd=bwd_times,
    )
158*da0073e9SAndroid Build Coastguard Worker
159*da0073e9SAndroid Build Coastguard Worker
def print_stderr(*args, **kwargs):
    """Call ``print`` with the given arguments, always writing to ``sys.stderr``.

    Any ``file`` keyword supplied by the caller is overridden.
    """
    forced_kwargs = dict(kwargs, file=sys.stderr)
    return print(*args, **forced_kwargs)
163*da0073e9SAndroid Build Coastguard Worker
164*da0073e9SAndroid Build Coastguard Worker
def print_json_oss_format(results):
    """Print the average run times as one JSON object on stdout.

    ``results`` maps a group name to a mapping of model name to a stats dict;
    only each stats dict's ``"avg"`` entry is emitted, giving
    ``{group: {model: avg}}``.
    """
    # Output for OSS
    averages_by_group = {
        group_name: {
            model_name: run_time["avg"]
            for model_name, run_time in group_val.items()
        }
        for group_name, group_val in results.items()
    }
    print(json.dumps(averages_by_group))
173*da0073e9SAndroid Build Coastguard Worker    print(json.dumps(oss_results))
174*da0073e9SAndroid Build Coastguard Worker
175*da0073e9SAndroid Build Coastguard Worker
def print_json_pep_format(results):
    """Print one AI-PEP record per measured iteration of every model.

    ``results`` maps a group name to a mapping of model name to a stats dict.
    Each stats dict's ``"info"`` entry is converted with ``tolist()`` and every
    element produces one stdout line: the text ``Caffe2Observer`` followed by
    a JSON object whose metric is ``<group>-<model>``.
    """
    for group_name, group_val in results.items():
        for model_name, run_time in group_val.items():
            # Output for AI-PEP
            for measurement in run_time["info"].tolist():
                record = {
                    "type": "NET",
                    "metric": group_name + "-" + model_name,
                    "unit": "ms",
                    "value": str(measurement),
                }
                print("Caffe2Observer " + json.dumps(record))
194*da0073e9SAndroid Build Coastguard Worker                )
195*da0073e9SAndroid Build Coastguard Worker
196*da0073e9SAndroid Build Coastguard Worker
def bench(rnn_runners, group_name, print_json=False, sep=" ", **params):
    """Benchmark every runner and collect forward/backward statistics.

    A header row and one row per successful runner are written through
    ``print_stderr``. Each runner is a ``(name, creator, context)`` triple;
    ``context()`` is entered as a context manager around the benchmark and
    ``params`` are forwarded to ``trainbench``.

    When ``print_json`` is truthy, a runner that raises is skipped without
    any message; otherwise the exception propagates.

    Returns:
        A dict with two entries, ``group_name`` for the forward pass and
        ``group_name + "-backward"`` for the backward pass. Each maps runner
        name to ``{"avg": ..., "std": ..., "info": ...}``.
    """
    print_stderr(print_header(sep=sep))

    completed = {}
    for name, creator, context in rnn_runners:
        with context():
            try:
                outcome = trainbench(name, creator, **params)
                # Show the literal text "None" in the two info columns rather
                # than the per-iteration tensors.
                table_row = outcome._replace(info_fwd="None", info_bwd="None")
                print_stderr(pretty_print(table_row, sep=sep))
                completed[name] = outcome
            except Exception:
                if print_json:
                    # JSON mode: leave this runner out and carry on.
                    continue
                raise

    forward_stats = {
        runner_name: {
            "avg": outcome.avg_fwd,
            "std": outcome.std_fwd,
            "info": outcome.info_fwd,
        }
        for runner_name, outcome in completed.items()
    }
    backward_stats = {
        runner_name: {
            "avg": outcome.avg_bwd,
            "std": outcome.std_bwd,
            "info": outcome.info_bwd,
        }
        for runner_name, outcome in completed.items()
    }
    return {group_name: forward_stats, group_name + "-backward": backward_stats}
223*da0073e9SAndroid Build Coastguard Worker
224*da0073e9SAndroid Build Coastguard Worker
def bench_group(model_list, bench_name, bench_group, bench_args):
    """Benchmark one named group of models and return its results.

    Announces the group through ``print_stderr``, resolves ``model_list`` via
    ``get_nn_runners``, runs ``bench`` with ``bench_group`` as the result key
    and ``bench_args`` as keyword arguments, then writes an empty line.
    """
    print_stderr(f"Benchmarking {bench_name}s...")
    runners = get_nn_runners(*model_list)
    group_results = bench(runners, bench_group, **bench_args)
    print_stderr("")
    return group_results
230*da0073e9SAndroid Build Coastguard Worker
231*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Profile RNNs")

    # Groups control which families of benchmarks run. To run only one or two
    # benchmarks, combine a group with an explicit model list, e.g.:
    #   python -m fastrnns.bench --rnns jit --group rnns
    default_groups = ["cnns", "rnns"]

    # Problem-size and iteration-count options; these are forwarded to the
    # benchmarks through `bench_args` below.
    parser.add_argument("--seqLength", default=100, type=int)
    parser.add_argument("--numLayers", default=1, type=int)
    parser.add_argument("--inputSize", default=512, type=int)
    parser.add_argument("--hiddenSize", default=512, type=int)
    parser.add_argument("--miniBatch", default=64, type=int)
    parser.add_argument("--warmup", default=10, type=int)
    parser.add_argument("--nloops", default=100, type=int)
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument(
        "--variable-lstms",
        "--variable_lstms",
        action="store_true",
        help="Also benchmark variable sequence length lstms "
        "Note that some of these run really slowly "
        "and that the `seqLength` flag will be ignored.",
    )
    parser.add_argument("--sep", default=" ", type=str)
    # A bare `--print-json` selects "oss". Restricting the choices makes an
    # unknown format an immediate usage error; previously it silenced stderr
    # and then printed nothing at all.
    parser.add_argument(
        "--print-json",
        nargs="?",
        default=None,
        const="oss",
        choices=("oss", "pep"),
    )
    parser.add_argument("--rnns", nargs="*", help="What to run. cudnn, aten, jit, etc")
    parser.add_argument(
        "--cnns", nargs="*", help="What to run. resnet18, resnet18_jit, resnet50, etc"
    )
    parser.add_argument(
        "--group",
        nargs="*",
        default=default_groups,
        help="Which group to run. cnns, rnns, etc.",
    )
    parser.add_argument(
        "--fuser",
        default="te",
        type=str,
        help="The fuser backend to use. One of: te, old, or none",
    )
    parser.add_argument(
        "--executor",
        default=None,
        type=str,
        help="The executor to use. One of: legacy, simple, profiling",
    )
    parser.add_argument(
        "--cuda-pointwise-loop-level",
        "--cuda_pointwise_loop_level",
        default=None,
        type=int,
    )
    parser.add_argument(
        "--cuda-pointwise-block-count",
        "--cuda_pointwise_block_count",
        default=None,
        type=int,
    )
    parser.add_argument(
        "--cuda-pointwise-block-size",
        "--cuda_pointwise_block_size",
        default=None,
        type=int,
    )

    args = parser.parse_args()
    set_fuser(args.fuser, args.executor)

    # Each override is applied only for a truthy value, so both an omitted
    # option and an explicit 0 leave the corresponding setting untouched.
    if args.cuda_pointwise_loop_level:
        torch._C._jit_set_te_cuda_pointwise_loop_levels(args.cuda_pointwise_loop_level)
    if args.cuda_pointwise_block_count:
        torch._C._jit_set_te_cuda_pointwise_block_count(args.cuda_pointwise_block_count)
    if args.cuda_pointwise_block_size:
        torch._C._jit_set_te_cuda_pointwise_block_size(args.cuda_pointwise_block_size)

    # An omitted or empty --rnns / --cnns falls back to the full default list.
    rnns = args.rnns or [
        "cudnn",
        "aten",
        "jit",
        "jit_premul",
        "jit_premul_bias",
        "jit_simple",
        "jit_multilayer",
        "py",
    ]
    cnns = args.cnns or ["resnet18", "resnet18_jit", "resnet50", "resnet50_jit"]
    # TODO: Maybe add a separate section for the layernorm/dropout lstms
    # 'cudnn_layernorm', 'jit_layernorm', 'jit_layernorm_decom',
    # 'jit', 'jit_dropout', 'cudnn_dropout'
    vlrnns = ["vl_cudnn", "vl_jit", "vl_py"]

    if args.print_json:
        # In JSON mode replace the module-level progress printer with a no-op;
        # the functions above look it up by name, so they go quiet as well.
        print_stderr = lambda *_args, **_kwargs: None  # noqa: E731,F811
    print_stderr(args)

    # Everything left in bench_args is passed on as keyword arguments to
    # bench(), so drop the options this script consumes itself.
    bench_args = copy.deepcopy(vars(args))
    should_bench_varlen_lstms = args.variable_lstms
    for script_only_option in (
        "group",
        "rnns",
        "cnns",
        "variable_lstms",
        "fuser",
        "executor",
        "cuda_pointwise_loop_level",
        "cuda_pointwise_block_count",
        "cuda_pointwise_block_size",
    ):
        del bench_args[script_only_option]

    results = {}
    if should_bench_varlen_lstms:
        if args.nloops + args.warmup > 30:
            print_stderr(
                "WARNING: some of the variable sequence length lstms are "
                "very unoptimized and therefore take forever to run."
            )
        results.update(
            bench_group(vlrnns, "variable-length sequence LSTM", "vl_lstm", bench_args)
        )

    if "rnns" in args.group:
        results.update(bench_group(rnns, "LSTM", "lstm", bench_args))
    if "cnns" in args.group:
        results.update(bench_group(cnns, "ResNet", "resnet", bench_args))

    if args.print_json == "oss":
        print_json_oss_format(results)
    elif args.print_json == "pep":
        print_json_pep_format(results)
360*da0073e9SAndroid Build Coastguard Worker        print_json_pep_format(results)
361