# xref: /aosp_15_r20/external/pytorch/test/mkldnn_verbose.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport argparse
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport torch
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Worker
class Module(torch.nn.Module):
    """Minimal network consisting of a single 2-D convolution.

    The convolution maps 1 input channel to 10 output channels with a
    5x5 kernel and stride 1.
    """

    def __init__(self) -> None:
        super().__init__()
        # Positional arguments: in_channels=1, out_channels=10,
        # kernel_size=5, stride=1.
        self.conv = torch.nn.Conv2d(1, 10, 5, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the convolution to ``x`` and return the result."""
        y = self.conv(x)
        return y
14*da0073e9SAndroid Build Coastguard Worker
15*da0073e9SAndroid Build Coastguard Worker
def run_model(level: int) -> None:
    """Run one forward pass of ``Module`` under the mkldnn verbose context.

    A fresh ``Module`` is put in eval mode and fed a random tensor of
    shape ``(1, 1, 112, 112)``. The call happens inside
    ``torch.backends.mkldnn.verbose(level)``; the output is discarded.

    Args:
        level: Value handed unchanged to ``torch.backends.mkldnn.verbose``.
    """
    m = Module().eval()
    d = torch.rand(1, 1, 112, 112)
    # Only the forward call is wrapped, so model construction and input
    # creation happen outside the verbose scope.
    with torch.backends.mkldnn.verbose(level):
        m(d)
21*da0073e9SAndroid Build Coastguard Worker
22*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Command-line entry point: read the verbose level (integer, default 0)
    # and run the model once with it.
    parser = argparse.ArgumentParser()
    parser.add_argument("--verbose-level", default=0, type=int)
    args = parser.parse_args()
    try:
        run_model(args.verbose_level)
    except Exception as e:
        # The error is printed instead of propagated, so the script still
        # exits normally. NOTE(review): presumably a caller inspects the
        # printed output; confirm before narrowing this handler.
        print(e)
31