xref: /aosp_15_r20/external/pytorch/benchmarks/tensorexpr/conv.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerfrom . import benchmark
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Worker
class ConvImplBench(benchmark.Benchmark):
    """Shared implementation of the convolution benchmarks.

    The ``case`` argument selects the grouping of the convolution layer:
    ``"conv"`` uses ``groups=1`` and ``"depthwise_conv"`` uses ``groups=iC``.
    Tensor creation, layer construction and device placement are delegated
    to helpers inherited from ``benchmark.Benchmark``.
    """

    def __init__(self, case, mode, device, dtype, kernel_size, N, iC, H, W, oC):
        """Create the random input and the convolution layer under test.

        Args:
            case: ``"conv"`` or ``"depthwise_conv"``.
            mode, device, dtype: forwarded unchanged to the base class.
            kernel_size: kernel size handed to ``conv2d_layer``.
            N, iC, H, W: dimensions of the random input, created with shape
                ``[N, iC, H, W]``.
            oC: second argument to ``conv2d_layer`` (presumably the number
                of output channels -- confirm against the base class).

        Raises:
            ValueError: if ``case`` is not one of the two supported names.
        """
        super().__init__(mode, device, dtype)
        self.case = case
        self.kernel_size = kernel_size
        self.N, self.iC, self.H, self.W, self.oC = N, iC, H, W, oC

        # The input is created before ``case`` is validated, so an invalid
        # case still goes through ``self.rand`` first.
        input_shape = [N, iC, H, W]
        self.data = self.rand(
            input_shape, device=device, requires_grad=self.requires_grad
        )

        if case not in ("conv", "depthwise_conv"):
            raise ValueError(f"invalid case: {case}")
        # Depthwise: one group per input channel; regular conv: one group.
        self.groups = 1 if case == "conv" else iC

        self.conv = self.conv2d_layer(iC, oC, kernel_size, groups=self.groups)
        if device != "cpu":
            self.to_device(self.conv, device)

    def forward(self):
        """Apply the convolution layer to the stored input and return the result."""
        return self.conv(self.data)

    def config(self):
        """Return ``[kernel_size, N, iC, H, W, oC]`` for this instance."""
        return [self.kernel_size, self.N, self.iC, self.H, self.W, self.oC]

    def memory_workload(self):
        """Estimate memory traffic, counted in tensor elements.

        No dtype size is applied here. Three buffers are considered (input,
        output, kernel weights); each is weighted by a per-mode pass count.

        Returns:
            dict with keys ``"sol"`` and ``"algorithmic"``. The values are
            floats because the kernel element count uses true division.
        """
        if self.mode == "fwd":
            sol_passes, algorithmic_passes = 1, 1
        else:
            # Any mode other than "fwd" is counted this way; presumably the
            # extra terms stand for backward-pass traffic -- TODO confirm.
            sol_passes, algorithmic_passes = 1 + 1, 1 + (1 + 1)

        k = self.kernel_size
        # NOTE(review): the output buffer uses the input's H * W, i.e. it
        # assumes the layer preserves spatial size -- confirm against
        # conv2d_layer's padding.
        element_counts = (
            self.N * self.iC * self.H * self.W,  # input
            self.N * self.oC * self.H * self.W,  # output
            self.oC * (self.iC / self.groups) * k * k,  # kernel weights
        )

        sol_size = 0
        algorithmic_size = 0
        for numel in element_counts:
            sol_size += numel * sol_passes
            algorithmic_size += numel * algorithmic_passes
        return {"sol": sol_size, "algorithmic": algorithmic_size}

    def compute_workload(self):
        """Estimate the operation count for the configured mode.

        The base count is ``N * iC / groups * oC * kernel_size**2 * H * W``,
        doubled, then multiplied by 1 for ``"fwd"`` or by 3 for ``"both"``.

        Raises:
            ValueError: if ``self.mode`` is neither ``"fwd"`` nor ``"both"``.
        """
        if self.mode not in ("fwd", "both"):
            raise ValueError(f"invalid mode: {self.mode}")
        passes = 1 if self.mode == "fwd" else 1 + (1 + 1)

        # Evaluated strictly left to right; the true division makes the
        # result a float.
        op_count = (
            self.N
            * self.iC
            / self.groups
            * self.oC
            * self.kernel_size
            * self.kernel_size
            * self.H
            * self.W
        )
        # Factor of two: presumably one multiply plus one add per term.
        op_count *= 2

        return op_count * passes

    @staticmethod
    def default_configs():
        """Return the default configuration list.

        Each entry is presumably ordered like ``config()``:
        ``[kernel_size, N, iC, H, W, oC]``.
        """
        return [
            [3, 64, 32, 128, 128, 64],
        ]
85*da0073e9SAndroid Build Coastguard Worker
86*da0073e9SAndroid Build Coastguard Worker
class ConvBench(ConvImplBench):
    """Convolution benchmark using the ``"conv"`` case (``groups=1``)."""

    def __init__(self, *args):
        """Forward ``*args`` to ``ConvImplBench`` with ``case="conv"`` prepended."""
        super().__init__("conv", *args)

    @staticmethod
    def module():
        """Return the name this benchmark reports as its module."""
        return "conv"
94*da0073e9SAndroid Build Coastguard Worker
95*da0073e9SAndroid Build Coastguard Worker
class DepthwiseConvBench(ConvImplBench):
    """Convolution benchmark using the ``"depthwise_conv"`` case (``groups=iC``)."""

    def __init__(self, *args):
        """Forward ``*args`` to ``ConvImplBench`` with ``case="depthwise_conv"`` prepended."""
        super().__init__("depthwise_conv", *args)

    @staticmethod
    def module():
        """Return the name this benchmark reports as its module."""
        return "depthwise_conv"
103*da0073e9SAndroid Build Coastguard Worker
104*da0073e9SAndroid Build Coastguard Worker
# Register both concrete benchmarks, in the same order as before.
for _bench_cls in (ConvBench, DepthwiseConvBench):
    benchmark.register_benchmark_class(_bench_cls)
del _bench_cls  # keep the module namespace free of the loop variable
107