from . import benchmark


class ConvImplBench(benchmark.Benchmark):
    """Shared implementation of the convolution benchmarks.

    ``case`` selects the flavour: ``"conv"`` uses a single group, while
    ``"depthwise_conv"`` uses one group per input channel.  The helpers
    ``rand``, ``conv2d_layer``, ``to_device`` and the ``requires_grad`` /
    ``mode`` attributes are not defined in this class; presumably they are
    provided by ``benchmark.Benchmark`` -- verify against the base class.
    """

    def __init__(self, case, mode, device, dtype, kernel_size, N, iC, H, W, oC):
        """Record the problem size, create the input and build the layer.

        Raises:
            ValueError: if ``case`` is neither ``"conv"`` nor
                ``"depthwise_conv"``.
        """
        super().__init__(mode, device, dtype)
        self.case = case
        self.kernel_size = kernel_size
        self.N, self.iC, self.H, self.W, self.oC = N, iC, H, W, oC

        # The input is created before the layer, same as it always was, so
        # any state consumed by the two helpers is consumed in that order.
        input_shape = [N, iC, H, W]
        self.data = self.rand(
            input_shape, device=device, requires_grad=self.requires_grad
        )

        if case not in ("conv", "depthwise_conv"):
            raise ValueError(f"invalid case: {case}")
        # Depthwise: one group per input channel; regular conv: one group.
        self.groups = iC if case == "depthwise_conv" else 1

        self.conv = self.conv2d_layer(iC, oC, kernel_size, groups=self.groups)
        if device != "cpu":
            self.to_device(self.conv, device)

    def forward(self):
        """Apply the convolution layer to the stored input and return it."""
        return self.conv(self.data)

    def config(self):
        """Return ``[kernel_size, N, iC, H, W, oC]`` for this instance."""
        return [self.kernel_size, self.N, self.iC, self.H, self.W, self.oC]

    def memory_workload(self):
        """Return element-count estimates of the memory traffic.

        The result is ``{"sol": ..., "algorithmic": ...}``.  Each value is
        the sum over the input, output and kernel buffers of
        ``buffer_size * access_count``.  Every mode other than ``"fwd"``
        uses the larger (fwd + bwd) access counts.
        """
        if self.mode == "fwd":
            sol_passes = 1
            algorithmic_passes = 1
        else:
            sol_passes = 1 + 1
            algorithmic_passes = 1 + (1 + 1)

        # The output size reuses the input H x W as-is.
        input_elems = self.N * self.iC * self.H * self.W
        output_elems = self.N * self.oC * self.H * self.W
        # True division on purpose: the kernel term (and therefore both
        # totals) stays a float, matching the established return values.
        kernel_elems = (
            self.oC
            * (self.iC / self.groups)
            * self.kernel_size
            * self.kernel_size
        )

        # Accumulate in input, output, kernel order.
        buffers = (input_elems, output_elems, kernel_elems)
        sol_total = sum(elems * sol_passes for elems in buffers)
        algorithmic_total = sum(elems * algorithmic_passes for elems in buffers)
        return {"sol": sol_total, "algorithmic": algorithmic_total}

    def compute_workload(self):
        """Return the estimated operation count for the configured mode.

        Raises:
            ValueError: if ``self.mode`` is neither ``"fwd"`` nor ``"both"``.
        """
        if self.mode not in ("fwd", "both"):
            raise ValueError(f"invalid mode: {self.mode}")
        # "both" counts three convolution-sized passes, "fwd" just one.
        passes = 1 if self.mode == "fwd" else 1 + (1 + 1)

        per_pass = (
            self.N
            * self.iC
            / self.groups
            * self.oC
            * self.kernel_size
            * self.kernel_size
            * self.H
            * self.W
        )
        # Factor of two: presumably one multiply plus one add per term.
        per_pass *= 2

        return per_pass * passes

    @staticmethod
    def default_configs():
        """Return the default ``[kernel_size, N, iC, H, W, oC]`` configs."""
        return [
            [3, 64, 32, 128, 128, 64],
        ]


class ConvBench(ConvImplBench):
    """Convolution benchmark with the case fixed to ``"conv"``."""

    def __init__(self, *args):
        """Forward ``args`` to ``ConvImplBench`` with case ``"conv"``."""
        super().__init__("conv", *args)

    @staticmethod
    def module():
        """Return the module name of this benchmark."""
        return "conv"


class DepthwiseConvBench(ConvImplBench):
    """Convolution benchmark with the case fixed to ``"depthwise_conv"``."""

    def __init__(self, *args):
        """Forward ``args`` to ``ConvImplBench`` with case ``"depthwise_conv"``."""
        super().__init__("depthwise_conv", *args)

    @staticmethod
    def module():
        """Return the module name of this benchmark."""
        return "depthwise_conv"


benchmark.register_benchmark_class(ConvBench)
benchmark.register_benchmark_class(DepthwiseConvBench)