# xref: /aosp_15_r20/external/pytorch/benchmarks/distributed/rpc/rl/coordinator.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
import time

import numpy as np
from agent import AgentBase
from observer import ObserverBase

import torch
import torch.distributed.rpc as rpc


COORDINATOR_NAME = "coordinator"
AGENT_NAME = "agent"
OBSERVER_NAME = "observer{}"

EPISODE_STEPS = 100


class CoordinatorBase:
    def __init__(self, batch_size, batch, state_size, nlayers, out_features):
        r"""
        Coordinator object that runs on its own worker.  Only one coordinator exists.
        It is responsible for facilitating communication between the agent and the
        observers and for recording benchmark throughput and latency data.
        Args:
            batch_size (int): Number of observer requests to process in a batch
            batch (bool): Whether to process and respond to observer requests as a batch or one at a time
            state_size (list): List of ints dictating the dimensions of the state
            nlayers (int): Number of layers in the model
            out_features (int): Number of output features in the model
        """
        self.batch_size = batch_size
        self.batch = batch

        self.agent_rref = None  # Agent RRef
        self.ob_rrefs = []  # Observer RRefs

        agent_info = rpc.get_worker_info(AGENT_NAME)
        self.agent_rref = rpc.remote(agent_info, AgentBase)

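        # Observer workers are assumed to occupy ranks 2..(batch_size + 1): rank 0
        # presumably hosts this coordinator and rank 1 the agent, which is why the
        # observer worker names below start at rank + 2.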
        for rank in range(batch_size):
            ob_info = rpc.get_worker_info(OBSERVER_NAME.format(rank + 2))
            ob_ref = rpc.remote(ob_info, ObserverBase)
            self.ob_rrefs.append(ob_ref)

            ob_ref.rpc_sync().set_state(state_size, batch)

        self.agent_rref.rpc_sync().set_world(
            batch_size, state_size, nlayers, out_features, self.batch
        )

    def run_coordinator(self, episodes, episode_steps, queue):
        r"""
        Runs n benchmark episodes.  Each episode is started by the coordinator telling each
        observer to contact the agent, and is concluded by the coordinator telling the agent
        to finish the episode, after which the coordinator records the benchmark data.
        Args:
            episodes (int): Number of episodes to run
            episode_steps (int): Number of steps to be run in each episode by each observer
            queue (SimpleQueue): SimpleQueue from torch.multiprocessing.get_context() for
                                 saving benchmark run results to
        """

        agent_latency_final = []
        agent_throughput_final = []

        observer_latency_final = []
        observer_throughput_final = []

        for ep in range(episodes):
            ep_start_time = time.time()

            print(f"Episode {ep} - ", end="")

            n_steps = episode_steps
            agent_start_time = time.time()

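            # Fan out one asynchronous episode per observer, wait for all of them to
            # return, then hand the collected results to the agent to finish the episode.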
            futs = []
            for ob_rref in self.ob_rrefs:
                futs.append(
                    ob_rref.rpc_async().run_ob_episode(self.agent_rref, n_steps)
                )

            rets = torch.futures.wait_all(futs)
            agent_latency, agent_throughput = self.agent_rref.rpc_sync().finish_episode(
                rets
            )

            self.agent_rref.rpc_sync().reset_metrics()

            agent_latency_final += agent_latency
            agent_throughput_final += agent_throughput

            observer_latency_final += [ret[2] for ret in rets]
            observer_throughput_final += [ret[3] for ret in rets]

            ep_end_time = time.time()
            episode_time = ep_end_time - ep_start_time
            print(round(episode_time, 3))

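        # Each observer reports its latency and throughput measurements as a list per
        # episode; flatten them into single lists before computing percentiles.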
        observer_latency_final = [t for s in observer_latency_final for t in s]
        observer_throughput_final = [t for s in observer_throughput_final for t in s]

        benchmark_metrics = {
            "agent latency (seconds)": {},
            "agent throughput": {},
            "observer latency (seconds)": {},
            "observer throughput": {},
        }

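        # Print and record p50/p75/p90/p95 for each metric: latencies are rounded to
        # three decimal places (seconds) and throughputs are truncated to integers.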
        print(f"For batch size {self.batch_size}")
        print("\nAgent Latency - ", len(agent_latency_final))
        agent_latency_final = sorted(agent_latency_final)
        for p in [50, 75, 90, 95]:
            v = np.percentile(agent_latency_final, p)
            print("p" + str(p) + ":", round(v, 3))
            p = f"p{p}"
            benchmark_metrics["agent latency (seconds)"][p] = round(v, 3)

        print("\nAgent Throughput - ", len(agent_throughput_final))
        agent_throughput_final = sorted(agent_throughput_final)
        for p in [50, 75, 90, 95]:
            v = np.percentile(agent_throughput_final, p)
            print("p" + str(p) + ":", int(v))
            p = f"p{p}"
            benchmark_metrics["agent throughput"][p] = int(v)

        print("\nObserver Latency - ", len(observer_latency_final))
        observer_latency_final = sorted(observer_latency_final)
        for p in [50, 75, 90, 95]:
            v = np.percentile(observer_latency_final, p)
            print("p" + str(p) + ":", round(v, 3))
            p = f"p{p}"
            benchmark_metrics["observer latency (seconds)"][p] = round(v, 3)

        print("\nObserver Throughput - ", len(observer_throughput_final))
        observer_throughput_final = sorted(observer_throughput_final)
        for p in [50, 75, 90, 95]:
            v = np.percentile(observer_throughput_final, p)
            print("p" + str(p) + ":", int(v))
            p = f"p{p}"
            benchmark_metrics["observer throughput"][p] = int(v)

        if queue:
            queue.put(benchmark_metrics)
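

# Hedged usage sketch (not part of the original benchmark): one plausible way to
# stand up the RPC world this coordinator expects, assuming rank 0 hosts the
# coordinator, rank 1 the agent, and ranks 2+ the observers. The benchmark's own
# launcher script is the real entry point; `_example_launch`, the master
# address/port, and the state_size/nlayers/out_features/episodes values below
# are illustrative assumptions, not the benchmark's defaults.
def _example_launch(rank, world_size, batch_size):
    import os

    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")

    if rank == 0:
        # Coordinator process: constructs the remote agent/observers and drives the run.
        rpc.init_rpc(COORDINATOR_NAME, rank=rank, world_size=world_size)
        coordinator = CoordinatorBase(
            batch_size=batch_size,
            batch=True,
            state_size=[10, 20, 10],
            nlayers=5,
            out_features=10,
        )
        coordinator.run_coordinator(
            episodes=10,
            episode_steps=EPISODE_STEPS,
            queue=None,  # metrics are only printed; pass a SimpleQueue to collect them
        )
    elif rank == 1:
        # Agent process: only serves the RPCs issued by the coordinator and observers.
        rpc.init_rpc(AGENT_NAME, rank=rank, world_size=world_size)
    else:
        # Observer processes: named "observer2", "observer3", ... to match __init__ above.
        rpc.init_rpc(OBSERVER_NAME.format(rank), rank=rank, world_size=world_size)

    rpc.shutdown()  # blocks until every worker in the group has finished


# A launcher would typically fan this out over world_size = batch_size + 2 processes
# (coordinator + agent + batch_size observers), e.g. with torch.multiprocessing:
#   mp.spawn(_example_launch, args=(world_size, batch_size), nprocs=world_size)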