"""Benchmark ``torch.add`` on tensor types with and without ``__torch_function__``.

For each tensor type, two one-element tensors are added repeatedly and the
minimum per-call time and the standard deviation are printed in microseconds.
"""

import argparse
import time

from common import SubTensor, SubWithTorchFunction, WithTorchFunction

import torch


NUM_REPEATS = 1000
NUM_REPEAT_OF_REPEATS = 1000


def bench(t1, t2, nreps=None, nrepreps=None):
    """Time ``torch.add(t1, t2)`` and return per-call statistics in seconds.

    ``torch.add`` is called ``nreps`` times per measurement, and ``nrepreps``
    measurements are taken.

    Args:
        t1: First operand passed to ``torch.add``.
        t2: Second operand passed to ``torch.add``.
        nreps: Calls per measurement. Defaults to the module-level
            ``NUM_REPEATS``, read at call time so ``main`` can override it.
        nrepreps: Number of measurements. Defaults to the module-level
            ``NUM_REPEAT_OF_REPEATS``, read at call time.

    Returns:
        A ``(bench_time, bench_std)`` tuple: the fastest measurement and the
        standard deviation of all measurements, each divided by ``nreps`` to
        give seconds per call. ``torch.std`` applies Bessel's correction, so
        ``bench_std`` is NaN when only one measurement is taken.

    Raises:
        ValueError: If ``nreps`` or ``nrepreps`` is less than 1.
    """
    if nreps is None:
        nreps = NUM_REPEATS
    if nrepreps is None:
        nrepreps = NUM_REPEAT_OF_REPEATS
    if nreps < 1 or nrepreps < 1:
        raise ValueError(
            f"nreps and nrepreps must be >= 1, got {nreps} and {nrepreps}"
        )

    bench_times = []
    for _ in range(nrepreps):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        time_start = time.perf_counter()
        for _ in range(nreps):
            torch.add(t1, t2)
        bench_times.append(time.perf_counter() - time_start)

    # float64 avoids the precision loss of the default float32 dtype.
    times = torch.tensor(bench_times, dtype=torch.float64)
    # Normalize by the actual number of calls per measurement (previously a
    # hard-coded 1000, which was wrong whenever --nreps was changed).
    bench_time = float(torch.min(times)) / nreps
    bench_std = float(torch.std(times)) / nreps

    return bench_time, bench_std


def main():
    """Parse command-line options, then benchmark and report each tensor type."""
    global NUM_REPEATS
    global NUM_REPEAT_OF_REPEATS

    parser = argparse.ArgumentParser(
        description="Run the __torch_function__ benchmarks."
    )
    parser.add_argument(
        "--nreps",
        "-n",
        type=int,
        default=NUM_REPEATS,
        help="The number of repeats for one measurement.",
    )
    parser.add_argument(
        "--nrepreps",
        "-m",
        type=int,
        default=NUM_REPEAT_OF_REPEATS,
        help="The number of measurements.",
    )
    args = parser.parse_args()
    if args.nreps < 1 or args.nrepreps < 1:
        parser.error("--nreps and --nrepreps must both be at least 1")

    # Keep the module-level settings in sync for any code that reads them.
    NUM_REPEATS = args.nreps
    NUM_REPEAT_OF_REPEATS = args.nrepreps

    types = torch.tensor, SubTensor, WithTorchFunction, SubWithTorchFunction

    for t in types:
        tensor_1 = t([1.0])
        tensor_2 = t([2.0])

        bench_min, bench_std = bench(
            tensor_1, tensor_2, NUM_REPEATS, NUM_REPEAT_OF_REPEATS
        )
        print(
            f"Type {t.__name__} had a minimum time of {10**6 * bench_min} us"
            f" and a standard deviation of {(10**6) * bench_std} us."
        )


if __name__ == "__main__":
    main()