1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: unknown"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport torch 4*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests, TemporaryFileName, TestCase 5*da0073e9SAndroid Build Coastguard Workerfrom torch.utils import ThroughputBenchmark 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Worker 8*da0073e9SAndroid Build Coastguard Workerclass TwoLayerNet(torch.jit.ScriptModule): 9*da0073e9SAndroid Build Coastguard Worker def __init__(self, D_in, H, D_out): 10*da0073e9SAndroid Build Coastguard Worker super().__init__() 11*da0073e9SAndroid Build Coastguard Worker self.linear1 = torch.nn.Linear(D_in, H) 12*da0073e9SAndroid Build Coastguard Worker self.linear2 = torch.nn.Linear(2 * H, D_out) 13*da0073e9SAndroid Build Coastguard Worker 14*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 15*da0073e9SAndroid Build Coastguard Worker def forward(self, x1, x2): 16*da0073e9SAndroid Build Coastguard Worker h1_relu = self.linear1(x1).clamp(min=0) 17*da0073e9SAndroid Build Coastguard Worker h2_relu = self.linear1(x2).clamp(min=0) 18*da0073e9SAndroid Build Coastguard Worker cat = torch.cat((h1_relu, h2_relu), 1) 19*da0073e9SAndroid Build Coastguard Worker y_pred = self.linear2(cat) 20*da0073e9SAndroid Build Coastguard Worker return y_pred 21*da0073e9SAndroid Build Coastguard Worker 22*da0073e9SAndroid Build Coastguard Worker 23*da0073e9SAndroid Build Coastguard Workerclass TwoLayerNetModule(torch.nn.Module): 24*da0073e9SAndroid Build Coastguard Worker def __init__(self, D_in, H, D_out): 25*da0073e9SAndroid Build Coastguard Worker super().__init__() 26*da0073e9SAndroid Build Coastguard Worker self.linear1 = torch.nn.Linear(D_in, H) 27*da0073e9SAndroid Build Coastguard Worker self.linear2 = torch.nn.Linear(2 * H, D_out) 28*da0073e9SAndroid Build Coastguard Worker 29*da0073e9SAndroid Build Coastguard Worker def forward(self, x1, x2): 30*da0073e9SAndroid Build Coastguard Worker h1_relu = self.linear1(x1).clamp(min=0) 31*da0073e9SAndroid Build Coastguard Worker h2_relu = self.linear1(x2).clamp(min=0) 32*da0073e9SAndroid Build Coastguard Worker cat = torch.cat((h1_relu, h2_relu), 1) 33*da0073e9SAndroid Build Coastguard Worker y_pred = self.linear2(cat) 34*da0073e9SAndroid Build Coastguard Worker return y_pred 35*da0073e9SAndroid Build Coastguard Worker 36*da0073e9SAndroid Build Coastguard Worker 37*da0073e9SAndroid Build Coastguard Workerclass TestThroughputBenchmark(TestCase): 38*da0073e9SAndroid Build Coastguard Worker def linear_test(self, Module, profiler_output_path=""): 39*da0073e9SAndroid Build Coastguard Worker D_in = 10 40*da0073e9SAndroid Build Coastguard Worker H = 5 41*da0073e9SAndroid Build Coastguard Worker D_out = 15 42*da0073e9SAndroid Build Coastguard Worker B = 8 43*da0073e9SAndroid Build Coastguard Worker NUM_INPUTS = 2 44*da0073e9SAndroid Build Coastguard Worker 45*da0073e9SAndroid Build Coastguard Worker module = Module(D_in, H, D_out) 46*da0073e9SAndroid Build Coastguard Worker 47*da0073e9SAndroid Build Coastguard Worker inputs = [] 48*da0073e9SAndroid Build Coastguard Worker 49*da0073e9SAndroid Build Coastguard Worker for i in range(NUM_INPUTS): 50*da0073e9SAndroid Build Coastguard Worker inputs.append([torch.randn(B, D_in), torch.randn(B, D_in)]) 51*da0073e9SAndroid Build Coastguard Worker bench = ThroughputBenchmark(module) 52*da0073e9SAndroid Build Coastguard Worker 53*da0073e9SAndroid Build Coastguard Worker for input in inputs: 54*da0073e9SAndroid Build Coastguard Worker # can do both args and kwargs here 55*da0073e9SAndroid Build Coastguard Worker bench.add_input(input[0], x2=input[1]) 56*da0073e9SAndroid Build Coastguard Worker 57*da0073e9SAndroid Build Coastguard Worker for i in range(NUM_INPUTS): 58*da0073e9SAndroid Build Coastguard Worker # or just unpack the list of inputs 59*da0073e9SAndroid Build Coastguard Worker module_result = module(*inputs[i]) 60*da0073e9SAndroid Build Coastguard Worker bench_result = bench.run_once(*inputs[i]) 61*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(bench_result, module_result) 62*da0073e9SAndroid Build Coastguard Worker 63*da0073e9SAndroid Build Coastguard Worker stats = bench.benchmark( 64*da0073e9SAndroid Build Coastguard Worker num_calling_threads=4, 65*da0073e9SAndroid Build Coastguard Worker num_warmup_iters=100, 66*da0073e9SAndroid Build Coastguard Worker num_iters=1000, 67*da0073e9SAndroid Build Coastguard Worker profiler_output_path=profiler_output_path, 68*da0073e9SAndroid Build Coastguard Worker ) 69*da0073e9SAndroid Build Coastguard Worker 70*da0073e9SAndroid Build Coastguard Worker print(stats) 71*da0073e9SAndroid Build Coastguard Worker 72*da0073e9SAndroid Build Coastguard Worker def test_script_module(self): 73*da0073e9SAndroid Build Coastguard Worker self.linear_test(TwoLayerNet) 74*da0073e9SAndroid Build Coastguard Worker 75*da0073e9SAndroid Build Coastguard Worker def test_module(self): 76*da0073e9SAndroid Build Coastguard Worker self.linear_test(TwoLayerNetModule) 77*da0073e9SAndroid Build Coastguard Worker 78*da0073e9SAndroid Build Coastguard Worker def test_profiling(self): 79*da0073e9SAndroid Build Coastguard Worker with TemporaryFileName() as fname: 80*da0073e9SAndroid Build Coastguard Worker self.linear_test(TwoLayerNetModule, profiler_output_path=fname) 81*da0073e9SAndroid Build Coastguard Worker 82*da0073e9SAndroid Build Coastguard Worker 83*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 84*da0073e9SAndroid Build Coastguard Worker run_tests() 85