"""Benchmark a ResNet18 inference "server" built from a frontend process that
sends requests and a multi-threaded, multi-stream CUDA backend."""

import argparse
import asyncio
import os.path
import subprocess
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from queue import Empty

import numpy as np
import pandas as pd

import torch
import torch.multiprocessing as mp

# Pretrained ResNet18 checkpoint loaded by the backend.
CHECKPOINT_NAME = "resnet18-f37072fd.pth"
CHECKPOINT_URL = "https://download.pytorch.org/models/resnet18-f37072fd.pth"
# Delay between successive nvidia-smi samples (seconds).
GPU_POLL_INTERVAL_S = 0.1


def checkpoint_path(model_dir):
    """Return the path of the ResNet18 checkpoint inside ``model_dir``."""
    return os.path.join(model_dir, CHECKPOINT_NAME)


def ensure_checkpoint(model_dir):
    """
    Make sure the ResNet18 checkpoint exists in ``model_dir``.

    Downloads it with ``wget`` directly into ``model_dir`` when missing.

    Returns:
        True if this call downloaded the file (so the caller should remove it
        afterwards), False if it was already present.

    Raises:
        RuntimeError: if ``wget`` exits with a non-zero status.
    """
    if os.path.isfile(checkpoint_path(model_dir)):
        return False
    os.makedirs(model_dir, exist_ok=True)
    # -P saves into model_dir, which is where the existence check, the
    # backend's torch.load and the final cleanup all look for the file.
    p = subprocess.run(["wget", "-P", model_dir, CHECKPOINT_URL])
    if p.returncode != 0:
        raise RuntimeError("Failed to download checkpoint")
    return True


def append_results(metrics, output_file, results_dir="./results"):
    """
    Append one row of ``metrics`` to ``results_dir/output_file`` as CSV.

    A header row is written only when the file does not exist yet, so
    repeated runs accumulate rows under a single header.

    Args:
        metrics: mapping of column name -> scalar value.
        output_file: CSV file name, relative to ``results_dir``.
        results_dir: directory holding the results; created if missing.

    Returns:
        The path that was written to.
    """
    output = pd.DataFrame.from_dict(
        {k: [v] for k, v in metrics.items()}, orient="columns"
    )
    output_path = os.path.join(results_dir, output_file)
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    is_empty = not os.path.isfile(output_path)

    with open(output_path, "a+", newline="") as file:
        output.to_csv(file, header=is_empty, index=False)
    return output_path


class FrontendWorker(mp.Process):
    """
    This worker will send requests to a backend process, and measure the
    throughput and latency of those requests as well as GPU utilization.
    """

    def __init__(
        self,
        metrics_dict,
        request_queue,
        response_queue,
        read_requests_event,
        batch_size,
        num_iters=10,
    ):
        super().__init__()
        self.metrics_dict = metrics_dict
        self.request_queue = request_queue
        self.response_queue = response_queue
        self.read_requests_event = read_requests_event
        # Set once the warmup response has been received.
        self.warmup_event = mp.Event()
        self.batch_size = batch_size
        self.num_iters = num_iters
        # Cleared by the metrics thread to stop the GPU polling thread.
        self.poll_gpu = True
        self.start_send_time = None
        self.end_recv_time = None

    def _run_metrics(self, metrics_lock):
        """
        Poll the response queue until every response has been received.

        The first response answers the warmup request; its round-trip time is
        stored as ``warmup_latency``. The next ``num_iters`` responses feed the
        average/max/min latency and the throughput. Throughput is measured in
        samples per second between the first non-warmup send and the last
        receive.

        Args:
            metrics_lock: lock guarding writes to ``self.metrics_dict``.
        """
        _, request_time = self.response_queue.get()
        # Take the timestamp before signalling the other threads.
        warmup_response_time = time.time() - request_time
        self.warmup_event.set()

        response_times = []
        for _ in range(self.num_iters):
            _, request_time = self.response_queue.get()
            response_times.append(time.time() - request_time)

        self.end_recv_time = time.time()
        self.poll_gpu = False

        response_times = np.array(response_times)
        with metrics_lock:
            self.metrics_dict["warmup_latency"] = warmup_response_time
            self.metrics_dict["average_latency"] = response_times.mean()
            self.metrics_dict["max_latency"] = response_times.max()
            self.metrics_dict["min_latency"] = response_times.min()
            self.metrics_dict["throughput"] = (self.num_iters * self.batch_size) / (
                self.end_recv_time - self.start_send_time
            )

    def _run_gpu_utilization(self, metrics_lock):
        """
        Poll nvidia-smi for GPU 0 utilization every ``GPU_POLL_INTERVAL_S``
        seconds until ``self.poll_gpu`` is cleared. The mean of the
        collected samples is recorded as ``gpu_util``; it is NaN if no sample
        could be read.

        Args:
            metrics_lock: lock guarding writes to ``self.metrics_dict``.
        """

        def get_gpu_utilization():
            """Return GPU 0 utilization as a float, or None if unavailable."""
            try:
                nvidia_smi_output = subprocess.check_output(
                    [
                        "nvidia-smi",
                        "--query-gpu=utilization.gpu",
                        "--id=0",
                        "--format=csv,noheader,nounits",
                    ]
                )
                return float(nvidia_smi_output.decode().strip())
            except (subprocess.CalledProcessError, OSError, ValueError):
                # nvidia-smi failed, is not installed, or printed a
                # non-numeric value such as "[N/A]": skip this sample.
                return None

        gpu_utilizations = []

        while self.poll_gpu:
            gpu_utilization = get_gpu_utilization()
            if gpu_utilization is not None:
                gpu_utilizations.append(gpu_utilization)
            time.sleep(GPU_POLL_INTERVAL_S)

        with metrics_lock:
            self.metrics_dict["gpu_util"] = torch.tensor(gpu_utilizations).mean().item()

    def _send_requests(self):
        """
        This function will send one warmup request, and then num_iters requests
        to the backend process.
        """

        fake_data = torch.randn(self.batch_size, 3, 250, 250, requires_grad=False)
        other_data = [
            torch.randn(self.batch_size, 3, 250, 250, requires_grad=False)
            for _ in range(self.num_iters)
        ]

        # Send one batch of warmup data
        self.request_queue.put((fake_data, time.time()))
        # Tell backend to poll queue for warmup request
        self.read_requests_event.set()
        self.warmup_event.wait()
        # Tell backend to poll queue for rest of requests
        self.read_requests_event.set()

        # Send fake data
        self.start_send_time = time.time()
        for batch in other_data:
            self.request_queue.put((batch, time.time()))

    def run(self):
        # Lock for writing to metrics_dict
        metrics_lock = threading.Lock()
        requests_thread = threading.Thread(target=self._send_requests)
        metrics_thread = threading.Thread(
            target=self._run_metrics, args=(metrics_lock,)
        )
        gpu_utilization_thread = threading.Thread(
            target=self._run_gpu_utilization, args=(metrics_lock,)
        )

        requests_thread.start()
        metrics_thread.start()

        # only start polling GPU utilization after the warmup request is complete
        self.warmup_event.wait()
        gpu_utilization_thread.start()

        requests_thread.join()
        metrics_thread.join()
        gpu_utilization_thread.join()


class BackendWorker:
    """
    This worker will take tensors from the request queue, do some computation,
    and then return the result back in the response queue.
    """

    def __init__(
        self,
        metrics_dict,
        request_queue,
        response_queue,
        read_requests_event,
        batch_size,
        num_workers,
        model_dir=".",
        compile_model=True,
    ):
        super().__init__()
        self.device = "cuda:0"
        self.metrics_dict = metrics_dict
        self.request_queue = request_queue
        self.response_queue = response_queue
        self.read_requests_event = read_requests_event
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.model_dir = model_dir
        self.compile_model = compile_model
        self._setup_complete = False
        self.h2d_stream = torch.cuda.Stream()
        self.d2h_stream = torch.cuda.Stream()
        # maps thread_id to the cuda.Stream associated with that worker thread
        self.stream_map = {}

    def _setup(self):
        """
        Build ResNet18 on the meta device and assign the checkpoint weights,
        loaded directly onto ``self.device``. Optionally ``compile()`` the
        model. Records ``torch_load_time`` and, when compiling,
        ``m_compile_time`` in ``metrics_dict``.

        Returns:
            The model in eval mode.
        """
        from torchvision.models.resnet import BasicBlock, ResNet

        # Create ResNet18 on meta device
        with torch.device("meta"):
            m = ResNet(BasicBlock, [2, 2, 2, 2])

        # Load pretrained weights
        start_load_time = time.time()
        state_dict = torch.load(
            checkpoint_path(self.model_dir),
            mmap=True,
            map_location=self.device,
        )
        self.metrics_dict["torch_load_time"] = time.time() - start_load_time
        # assign=True replaces the meta tensors with the loaded ones.
        m.load_state_dict(state_dict, assign=True)
        m.eval()

        if self.compile_model:
            start_compile_time = time.time()
            m.compile()
            end_compile_time = time.time()
            self.metrics_dict["m_compile_time"] = end_compile_time - start_compile_time
        return m

    def model_predict(
        self,
        model,
        input_buffer,
        copy_event,
        compute_event,
        copy_sem,
        compute_sem,
        response_list,
        request_time,
    ):
        """
        Run ``model`` on ``input_buffer`` on this worker thread's CUDA stream.

        Waits until the copy thread has recorded ``copy_event`` and makes the
        stream wait on it. The output is appended to ``response_list``. This
        method then records ``compute_event`` and releases ``compute_sem``
        for the respond thread.
        """
        # copy_sem makes sure copy_event has been recorded in the data copying thread
        copy_sem.acquire()
        stream = self.stream_map[threading.get_native_id()]
        stream.wait_event(copy_event)
        with torch.cuda.stream(stream):
            with torch.no_grad():
                response_list.append(model(input_buffer))
                compute_event.record()
                compute_sem.release()
        del input_buffer

    def copy_data(self, input_buffer, data, copy_event, copy_sem):
        """
        Copy ``data`` into ``input_buffer``.

        ``data`` is pinned first, then copied with a non-blocking copy on the
        host-to-device stream. ``copy_event`` is recorded on that stream, and
        ``copy_sem`` is released for the compute thread.
        """
        data = data.pin_memory()
        with torch.cuda.stream(self.h2d_stream):
            input_buffer.copy_(data, non_blocking=True)
            copy_event.record()
            copy_sem.release()

    def respond(self, compute_event, compute_sem, response_list, request_time):
        """
        Move the model output to the CPU and put it, together with
        ``request_time``, on the response queue.

        The copy runs on the device-to-host stream once ``compute_event``
        has been recorded by the compute thread.
        """
        # compute_sem makes sure compute_event has been recorded in the model_predict thread
        compute_sem.acquire()
        self.d2h_stream.wait_event(compute_event)
        with torch.cuda.stream(self.d2h_stream):
            self.response_queue.put((response_list[0].cpu(), request_time))

    async def run(self):
        """
        Serve requests until the request queue stays empty for 5 seconds.

        Each request is handled in three stages:
        - copy to the GPU, on a single-thread pool;
        - model forward, on a ``num_workers``-thread pool with one CUDA stream
          per thread;
        - copy back and respond, on a single-thread pool.

        The first request triggers model setup. After it, this method waits
        for ``read_requests_event`` a second time before reading the
        remaining requests. All pools are drained before returning.
        """

        def worker_initializer():
            self.stream_map[threading.get_native_id()] = torch.cuda.Stream()

        worker_pool = ThreadPoolExecutor(
            max_workers=self.num_workers, initializer=worker_initializer
        )
        h2d_pool = ThreadPoolExecutor(max_workers=1)
        d2h_pool = ThreadPoolExecutor(max_workers=1)
        loop = asyncio.get_running_loop()

        try:
            self.read_requests_event.wait()
            # Clear as we will wait for this event again before continuing to
            # poll the request_queue for the non-warmup requests
            self.read_requests_event.clear()
            while True:
                try:
                    data, request_time = self.request_queue.get(timeout=5)
                except Empty:
                    break

                if not self._setup_complete:
                    model = self._setup()

                copy_sem = threading.Semaphore(0)
                compute_sem = threading.Semaphore(0)
                copy_event = torch.cuda.Event()
                compute_event = torch.cuda.Event()
                response_list = []
                input_buffer = torch.empty(
                    [self.batch_size, 3, 250, 250],
                    dtype=torch.float32,
                    device=self.device,
                )
                loop.run_in_executor(
                    h2d_pool,
                    self.copy_data,
                    input_buffer,
                    data,
                    copy_event,
                    copy_sem,
                )
                loop.run_in_executor(
                    worker_pool,
                    self.model_predict,
                    model,
                    input_buffer,
                    copy_event,
                    compute_event,
                    copy_sem,
                    compute_sem,
                    response_list,
                    request_time,
                )
                loop.run_in_executor(
                    d2h_pool,
                    self.respond,
                    compute_event,
                    compute_sem,
                    response_list,
                    request_time,
                )

                if not self._setup_complete:
                    self.read_requests_event.wait()
                    self._setup_complete = True
        finally:
            # Drain the pipeline in stage order so every submitted request has
            # produced its response and the pool threads are released.
            for pool in (h2d_pool, worker_pool, d2h_pool):
                pool.shutdown(wait=True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_iters", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--model_dir", type=str, default=".")
    parser.add_argument(
        "--compile", default=True, action=argparse.BooleanOptionalAction
    )
    parser.add_argument("--output_file", type=str, default="output.csv")
    parser.add_argument(
        "--profile", default=False, action=argparse.BooleanOptionalAction
    )
    parser.add_argument("--num_workers", type=int, default=4)
    args = parser.parse_args()

    downloaded_checkpoint = ensure_checkpoint(args.model_dir)

    try:
        mp.set_start_method("forkserver")
        request_queue = mp.Queue()
        response_queue = mp.Queue()
        read_requests_event = mp.Event()

        manager = mp.Manager()
        metrics_dict = manager.dict()
        metrics_dict["batch_size"] = args.batch_size
        metrics_dict["compile"] = args.compile

        frontend = FrontendWorker(
            metrics_dict,
            request_queue,
            response_queue,
            read_requests_event,
            args.batch_size,
            num_iters=args.num_iters,
        )
        backend = BackendWorker(
            metrics_dict,
            request_queue,
            response_queue,
            read_requests_event,
            args.batch_size,
            args.num_workers,
            args.model_dir,
            args.compile,
        )

        frontend.start()

        if args.profile:

            def trace_handler(prof):
                prof.export_chrome_trace("trace.json")

            with torch.profiler.profile(on_trace_ready=trace_handler):
                asyncio.run(backend.run())
        else:
            asyncio.run(backend.run())

        frontend.join()

        # .copy() returns a plain dict snapshot of the manager-backed dict.
        append_results(metrics_dict.copy(), args.output_file)

    finally:
        # Cleanup checkpoint file if we downloaded it
        if downloaded_checkpoint:
            os.remove(checkpoint_path(args.model_dir))