xref: /aosp_15_r20/external/pytorch/benchmarks/inference/server.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport argparse
2*da0073e9SAndroid Build Coastguard Workerimport asyncio
3*da0073e9SAndroid Build Coastguard Workerimport os.path
4*da0073e9SAndroid Build Coastguard Workerimport subprocess
5*da0073e9SAndroid Build Coastguard Workerimport threading
6*da0073e9SAndroid Build Coastguard Workerimport time
7*da0073e9SAndroid Build Coastguard Workerfrom concurrent.futures import ThreadPoolExecutor
8*da0073e9SAndroid Build Coastguard Workerfrom queue import Empty
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Workerimport numpy as np
11*da0073e9SAndroid Build Coastguard Workerimport pandas as pd
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Workerimport torch
14*da0073e9SAndroid Build Coastguard Workerimport torch.multiprocessing as mp
15*da0073e9SAndroid Build Coastguard Worker
16*da0073e9SAndroid Build Coastguard Worker
class FrontendWorker(mp.Process):
    """
    This worker will send requests to a backend process, and measure the
    throughput and latency of those requests as well as GPU utilization.

    The results are written into ``metrics_dict`` under the keys
    ``warmup_latency``, ``average_latency``, ``max_latency``, ``min_latency``,
    ``throughput`` and ``gpu_util``.
    """

    # Seconds to pause between two consecutive nvidia-smi samples.
    GPU_POLL_INTERVAL_S = 0.1

    def __init__(
        self,
        metrics_dict,
        request_queue,
        response_queue,
        read_requests_event,
        batch_size,
        num_iters=10,
    ):
        """
        Args:
            metrics_dict: mapping that receives the measured metrics.
            request_queue: queue on which ``(tensor, send_time)`` requests are put.
            response_queue: queue from which ``(response, send_time)`` replies
                are read.
            read_requests_event: event set to tell the backend to start reading
                ``request_queue``.
            batch_size: leading dimension of every request tensor.
            num_iters: number of timed requests sent after the warmup request.
        """
        super().__init__()
        self.metrics_dict = metrics_dict
        self.request_queue = request_queue
        self.response_queue = response_queue
        self.read_requests_event = read_requests_event
        # Set by the metrics thread once the warmup response has arrived.
        self.warmup_event = mp.Event()
        self.batch_size = batch_size
        self.num_iters = num_iters
        # Cleared by the metrics thread to stop the GPU polling loop.
        self.poll_gpu = True
        self.start_send_time = None
        self.end_recv_time = None

    def _run_metrics(self, metrics_lock):
        """
        This function will poll the response queue until it has received all
        responses. It records the startup latency, the average, max, min latency
        as well as throughput of requests.

        When no timed response was collected (``num_iters == 0``) the latency
        statistics are recorded as NaN instead of raising on an empty array.
        """
        warmup_response_time = None
        response_times = []

        # One extra iteration accounts for the warmup response.
        for _ in range(self.num_iters + 1):
            _response, request_time = self.response_queue.get()
            if warmup_response_time is None:
                self.warmup_event.set()
                warmup_response_time = time.time() - request_time
            else:
                response_times.append(time.time() - request_time)

        self.end_recv_time = time.time()
        self.poll_gpu = False

        nan = float("nan")
        if response_times:
            latencies = np.array(response_times)
            average_latency = latencies.mean()
            max_latency = latencies.max()
            min_latency = latencies.min()
        else:
            average_latency = max_latency = min_latency = nan

        # start_send_time is written by the sender thread; guard against it
        # not being set yet and against a zero-length measurement window.
        throughput = nan
        if self.start_send_time is not None:
            elapsed = self.end_recv_time - self.start_send_time
            if elapsed > 0:
                throughput = (self.num_iters * self.batch_size) / elapsed

        with metrics_lock:
            self.metrics_dict["warmup_latency"] = warmup_response_time
            self.metrics_dict["average_latency"] = average_latency
            self.metrics_dict["max_latency"] = max_latency
            self.metrics_dict["min_latency"] = min_latency
            self.metrics_dict["throughput"] = throughput

    def _run_gpu_utilization(self, metrics_lock):
        """
        This function will poll nvidia-smi for GPU utilization every 100ms to
        record the average GPU utilization.

        Samples that cannot be obtained or parsed are skipped. If no sample was
        collected at all, ``gpu_util`` is recorded as NaN.
        """

        def get_gpu_utilization():
            """Return one utilization sample as a float, or None on failure."""
            try:
                nvidia_smi_output = subprocess.check_output(
                    [
                        "nvidia-smi",
                        "--query-gpu=utilization.gpu",
                        "--id=0",
                        "--format=csv,noheader,nounits",
                    ]
                )
            except (subprocess.CalledProcessError, OSError):
                # Non-zero exit status, or nvidia-smi could not be executed.
                return None
            try:
                return float(nvidia_smi_output.decode().strip())
            except ValueError:
                # Output was not a single number.
                return None

        gpu_utilizations = []

        while self.poll_gpu:
            sample = get_gpu_utilization()
            if sample is not None:
                gpu_utilizations.append(sample)
            # Pace the polling so it does not spawn nvidia-smi back to back.
            time.sleep(self.GPU_POLL_INTERVAL_S)

        if gpu_utilizations:
            average_utilization = torch.tensor(gpu_utilizations).mean().item()
        else:
            average_utilization = float("nan")

        with metrics_lock:
            self.metrics_dict["gpu_util"] = average_utilization

    def _send_requests(self):
        """
        This function will send one warmup request, and then num_iters requests
        to the backend process.
        """

        # All tensors are created up front, before the timed sends begin.
        fake_data = torch.randn(self.batch_size, 3, 250, 250, requires_grad=False)
        other_data = [
            torch.randn(self.batch_size, 3, 250, 250, requires_grad=False)
            for _ in range(self.num_iters)
        ]

        # Send one batch of warmup data
        self.request_queue.put((fake_data, time.time()))
        # Tell backend to poll queue for warmup request
        self.read_requests_event.set()
        self.warmup_event.wait()
        # Tell backend to poll queue for rest of requests
        self.read_requests_event.set()

        # Send fake data
        self.start_send_time = time.time()
        for batch in other_data:
            self.request_queue.put((batch, time.time()))

    def run(self):
        """Run the sender, metrics and GPU-polling threads to completion."""
        # Lock for writing to metrics_dict
        metrics_lock = threading.Lock()
        requests_thread = threading.Thread(target=self._send_requests)
        metrics_thread = threading.Thread(
            target=self._run_metrics, args=(metrics_lock,)
        )
        gpu_utilization_thread = threading.Thread(
            target=self._run_gpu_utilization, args=(metrics_lock,)
        )

        requests_thread.start()
        metrics_thread.start()

        # only start polling GPU utilization after the warmup request is complete
        self.warmup_event.wait()
        gpu_utilization_thread.start()

        requests_thread.join()
        metrics_thread.join()
        gpu_utilization_thread.join()
152*da0073e9SAndroid Build Coastguard Worker
class BackendWorker:
    """
    This worker will take tensors from the request queue, do some computation,
    and then return the result back in the response queue.

    Each request goes through three stages that run on separate thread pools:
    ``copy_data`` (host to device), ``model_predict`` (forward pass) and
    ``respond`` (device to host + reply). Consecutive stages are ordered with a
    CUDA event, and a semaphore guarantees the event has been recorded before
    the next stage waits on it.
    """

    def __init__(
        self,
        metrics_dict,
        request_queue,
        response_queue,
        read_requests_event,
        batch_size,
        num_workers,
        model_dir=".",
        compile_model=True,
    ):
        """
        Args:
            metrics_dict: mapping that receives ``torch_load_time`` and, when
                compiling, ``m_compile_time``.
            request_queue: queue yielding ``(tensor, request_time)`` requests.
            response_queue: queue receiving ``(cpu_tensor, request_time)``.
            read_requests_event: event the frontend sets to allow reading
                ``request_queue``.
            batch_size: leading dimension of the device input buffer.
            num_workers: number of threads in the compute pool.
            model_dir: directory containing ``resnet18-f37072fd.pth``.
            compile_model: whether ``_setup`` calls ``compile()`` on the model.
        """
        super().__init__()
        self.device = "cuda:0"
        self.metrics_dict = metrics_dict
        self.request_queue = request_queue
        self.response_queue = response_queue
        self.read_requests_event = read_requests_event
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.model_dir = model_dir
        self.compile_model = compile_model
        self._setup_complete = False
        self.h2d_stream = torch.cuda.Stream()
        self.d2h_stream = torch.cuda.Stream()
        # maps thread_id to the cuda.Stream associated with that worker thread
        self.stream_map = {}

    def _setup(self):
        """
        Build ResNet18, load its checkpoint from ``model_dir`` and optionally
        compile it. Load and compile durations are stored in ``metrics_dict``.

        Returns:
            The model in eval mode.
        """
        # Imported lazily so torchvision is only required once a request arrives.
        from torchvision.models.resnet import BasicBlock, ResNet

        # Create ResNet18 on meta device
        with torch.device("meta"):
            m = ResNet(BasicBlock, [2, 2, 2, 2])

        # Load pretrained weights
        start_load_time = time.time()
        state_dict = torch.load(
            f"{self.model_dir}/resnet18-f37072fd.pth",
            mmap=True,
            map_location=self.device,
        )
        self.metrics_dict["torch_load_time"] = time.time() - start_load_time
        # assign=True replaces the meta parameters with the loaded tensors.
        m.load_state_dict(state_dict, assign=True)
        m.eval()

        if self.compile_model:
            start_compile_time = time.time()
            m.compile()
            end_compile_time = time.time()
            self.metrics_dict["m_compile_time"] = end_compile_time - start_compile_time
        return m

    def model_predict(
        self,
        model,
        input_buffer,
        copy_event,
        compute_event,
        copy_sem,
        compute_sem,
        response_list,
        request_time,
    ):
        """
        Run ``model`` on ``input_buffer`` on the calling thread's CUDA stream and
        append the output to ``response_list``.

        ``request_time`` is accepted for signature parity with the other stages
        and is not used here.
        """
        # copy_sem makes sure copy_event has been recorded in the data copying thread
        copy_sem.acquire()
        stream = self.stream_map[threading.get_native_id()]
        stream.wait_event(copy_event)
        with torch.cuda.stream(stream):
            with torch.no_grad():
                response_list.append(model(input_buffer))
                # Recorded while this worker's stream is the current stream.
                compute_event.record()
                compute_sem.release()

    def copy_data(self, input_buffer, data, copy_event, copy_sem):
        """
        Copy ``data`` into the device ``input_buffer`` on the h2d stream, record
        ``copy_event`` and release ``copy_sem`` to unblock ``model_predict``.
        """
        # Pinned host memory is needed for the copy below to be asynchronous.
        data = data.pin_memory()
        with torch.cuda.stream(self.h2d_stream):
            input_buffer.copy_(data, non_blocking=True)
            copy_event.record()
            copy_sem.release()

    def respond(self, compute_event, compute_sem, response_list, request_time):
        """
        Copy the first entry of ``response_list`` to the host on the d2h stream
        and put ``(cpu_tensor, request_time)`` on the response queue.
        """
        # compute_sem makes sure compute_event has been recorded in the model_predict thread
        compute_sem.acquire()
        self.d2h_stream.wait_event(compute_event)
        with torch.cuda.stream(self.d2h_stream):
            self.response_queue.put((response_list[0].cpu(), request_time))

    async def run(self):
        """
        Serve requests until ``request_queue`` stays empty for 5 seconds.

        The first (warmup) request triggers model setup; afterwards the loop
        waits for ``read_requests_event`` again before reading the remaining
        requests. Once the queue is drained, every submitted stage is awaited
        so that an exception raised inside a pool thread is re-raised here
        instead of being dropped, and the pools are shut down.
        """

        def worker_initializer():
            # Give every compute thread its own CUDA stream.
            self.stream_map[threading.get_native_id()] = torch.cuda.Stream()

        loop = asyncio.get_running_loop()
        worker_pool = ThreadPoolExecutor(
            max_workers=self.num_workers, initializer=worker_initializer
        )
        h2d_pool = ThreadPoolExecutor(max_workers=1)
        d2h_pool = ThreadPoolExecutor(max_workers=1)

        pending = []
        model = None
        try:
            self.read_requests_event.wait()
            # Clear as we will wait for this event again before continuing to
            # poll the request_queue for the non-warmup requests
            self.read_requests_event.clear()
            while True:
                try:
                    data, request_time = self.request_queue.get(timeout=5)
                except Empty:
                    break

                if not self._setup_complete:
                    model = self._setup()

                copy_sem = threading.Semaphore(0)
                compute_sem = threading.Semaphore(0)
                copy_event = torch.cuda.Event()
                compute_event = torch.cuda.Event()
                response_list = []
                input_buffer = torch.empty(
                    [self.batch_size, 3, 250, 250],
                    dtype=torch.float32,
                    device=self.device,
                )
                pending.append(
                    loop.run_in_executor(
                        h2d_pool,
                        self.copy_data,
                        input_buffer,
                        data,
                        copy_event,
                        copy_sem,
                    )
                )
                pending.append(
                    loop.run_in_executor(
                        worker_pool,
                        self.model_predict,
                        model,
                        input_buffer,
                        copy_event,
                        compute_event,
                        copy_sem,
                        compute_sem,
                        response_list,
                        request_time,
                    )
                )
                pending.append(
                    loop.run_in_executor(
                        d2h_pool,
                        self.respond,
                        compute_event,
                        compute_sem,
                        response_list,
                        request_time,
                    )
                )

                if not self._setup_complete:
                    self.read_requests_event.wait()
                    self._setup_complete = True

            # Re-raise the first exception raised by any pipeline stage.
            await asyncio.gather(*pending)
        finally:
            # wait=False: after a failure a stage may still be blocked on its
            # semaphore, and waiting for it here would hide the error.
            for pool in (h2d_pool, worker_pool, d2h_pool):
                pool.shutdown(wait=False)
314*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Benchmark entry point: parse options, make sure the ResNet18 checkpoint is
    # available, run the frontend process and the in-process backend, then
    # append the collected metrics as one CSV row.
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_iters", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--model_dir", type=str, default=".")
    parser.add_argument(
        "--compile", default=True, action=argparse.BooleanOptionalAction
    )
    parser.add_argument("--output_file", type=str, default="output.csv")
    parser.add_argument(
        "--profile", default=False, action=argparse.BooleanOptionalAction
    )
    parser.add_argument("--num_workers", type=int, default=4)
    args = parser.parse_args()

    checkpoint_path = f"{args.model_dir}/resnet18-f37072fd.pth"
    downloaded_checkpoint = False
    if not os.path.isfile(checkpoint_path):
        # -P saves the file under model_dir, which is where it is looked up,
        # loaded from and later removed; without it wget writes to the current
        # directory and any other --model_dir would not find the checkpoint.
        p = subprocess.run(
            [
                "wget",
                "-P",
                args.model_dir,
                "https://download.pytorch.org/models/resnet18-f37072fd.pth",
            ]
        )
        if p.returncode == 0:
            downloaded_checkpoint = True
        else:
            raise RuntimeError("Failed to download checkpoint")

    try:
        mp.set_start_method("forkserver")
        request_queue = mp.Queue()
        response_queue = mp.Queue()
        read_requests_event = mp.Event()

        # Manager-backed dict so both processes write into the same mapping.
        manager = mp.Manager()
        metrics_dict = manager.dict()
        metrics_dict["batch_size"] = args.batch_size
        metrics_dict["compile"] = args.compile

        frontend = FrontendWorker(
            metrics_dict,
            request_queue,
            response_queue,
            read_requests_event,
            args.batch_size,
            num_iters=args.num_iters,
        )
        backend = BackendWorker(
            metrics_dict,
            request_queue,
            response_queue,
            read_requests_event,
            args.batch_size,
            args.num_workers,
            args.model_dir,
            args.compile,
        )

        frontend.start()

        if args.profile:

            def trace_handler(prof):
                prof.export_chrome_trace("trace.json")

            with torch.profiler.profile(on_trace_ready=trace_handler):
                asyncio.run(backend.run())
        else:
            asyncio.run(backend.run())

        frontend.join()

        # One single-element column per metric -> a one-row DataFrame.
        metrics_dict = {k: [v] for k, v in metrics_dict._getvalue().items()}
        output = pd.DataFrame.from_dict(metrics_dict, orient="columns")
        results_dir = "./results/"
        # The results directory may not exist on a fresh checkout.
        os.makedirs(results_dir, exist_ok=True)
        output_file = results_dir + args.output_file
        # Write the header only when the file is being created.
        is_empty = not os.path.isfile(output_file)

        with open(output_file, "a+", newline="") as file:
            output.to_csv(file, header=is_empty, index=False)

    finally:
        # Cleanup checkpoint file if we downloaded it
        if downloaded_checkpoint:
            os.remove(checkpoint_path)
399