# mypy: allow-untyped-defs

# If you need to modify this file to make this test pass, please also apply the same edits
# to https://github.com/pytorch/examples/blob/master/distributed/rpc/rl/main.py
# and https://pytorch.org/tutorials/intermediate/rpc_tutorial.html

from itertools import count

import numpy as np

import torch
import torch.distributed.rpc as rpc
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributed.rpc import RRef, rpc_sync, rpc_async, remote
from torch.distributions import Categorical

from torch.testing._internal.dist_utils import dist_init, worker_name
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import RpcAgentTestFixture

TOTAL_EPISODE_STEP = 5000
GAMMA = 0.1
SEED = 543


def _call_method(method, rref, *args, **kwargs):
    r"""
    A helper function to call a method on the given RRef.
    """
    return method(rref.local_value(), *args, **kwargs)


def _remote_method(method, rref, *args, **kwargs):
    r"""
    A helper function to run a method on the owner of the given RRef and fetch
    the result back using RPC.
    """
    args = [method, rref] + list(args)
    return rpc_sync(rref.owner(), _call_method, args=args, kwargs=kwargs)


class Policy(nn.Module):
    r"""
    Borrows the ``Policy`` class from the Reinforcement Learning example.
    The code is copied here to keep the two examples independent.
    See https://github.com/pytorch/examples/tree/master/reinforcement_learning
    """
    def __init__(self) -> None:
        super().__init__()
        self.affine1 = nn.Linear(4, 128)
        self.dropout = nn.Dropout(p=0.6)
        self.affine2 = nn.Linear(128, 2)

        self.saved_log_probs = []
        self.rewards = []

    def forward(self, x):
        x = self.affine1(x)
        x = self.dropout(x)
        x = F.relu(x)
        action_scores = self.affine2(x)
        return F.softmax(action_scores, dim=1)


class DummyEnv:
    r"""
    A dummy environment that implements the required subset of the OpenAI gym
    interface. It exists only to avoid a dependency on gym for running the
    tests in this file. It runs for a fixed maximum number of iterations,
    returning random states and rewards at each step.
    """
    def __init__(self, state_dim=4, num_iters=10, reward_threshold=475.0):
        self.state_dim = state_dim
        self.num_iters = num_iters
        self.iter = 0
        self.reward_threshold = reward_threshold

    def seed(self, manual_seed):
        torch.manual_seed(manual_seed)

    def reset(self):
        self.iter = 0
        return torch.randn(self.state_dim)

    def step(self, action):
        self.iter += 1
        state = torch.randn(self.state_dim)
        reward = torch.rand(1).item() * self.reward_threshold
        done = self.iter >= self.num_iters
        info = {}
        return state, reward, done, info


class Observer:
    r"""
    An observer has exclusive access to its own environment. Each observer
    captures the state from its environment and sends the state to the agent to
    select an action. Then, the observer applies the action to its environment
    and reports the reward to the agent.
    """
    def __init__(self) -> None:
        self.id = rpc.get_worker_info().id
        self.env = DummyEnv()
        self.env.seed(SEED)

    def run_episode(self, agent_rref, n_steps):
        r"""
        Run one episode of ``n_steps``.
        Arguments:
            agent_rref (RRef): an RRef referencing the agent object.
            n_steps (int): number of steps in this episode
        """
        state, ep_reward = self.env.reset(), 0
        for step in range(n_steps):
            # send the state to the agent to get an action
            action = _remote_method(Agent.select_action, agent_rref, self.id, state)

            # apply the action to the environment, and get the reward
            state, reward, done, _ = self.env.step(action)

            # report the reward to the agent for training purposes
            _remote_method(Agent.report_reward, agent_rref, self.id, reward)

            if done:
                break


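# NOTE: an alternative to the ``_call_method`` / ``_remote_method`` pattern used above
# is the RRef proxy API (``RRef.rpc_sync()`` / ``RRef.rpc_async()``) available in recent
# PyTorch releases. A roughly equivalent form of the call in ``Observer.run_episode``
# would be (illustrative sketch only, not used by this test):
#
#     action = agent_rref.rpc_sync().select_action(self.id, state)

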
class Agent:
    def __init__(self, world_size):
        self.ob_rrefs = []
        self.agent_rref = RRef(self)
        self.rewards = {}
        self.saved_log_probs = {}
        self.policy = Policy()
        self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-2)
        self.eps = np.finfo(np.float32).eps.item()
        self.running_reward = 0
        self.reward_threshold = DummyEnv().reward_threshold
        for ob_rank in range(1, world_size):
            ob_info = rpc.get_worker_info(worker_name(ob_rank))
            self.ob_rrefs.append(remote(ob_info, Observer))
            self.rewards[ob_info.id] = []
            self.saved_log_probs[ob_info.id] = []

    def select_action(self, ob_id, state):
        r"""
        This function is mostly borrowed from the Reinforcement Learning example.
        See https://github.com/pytorch/examples/tree/master/reinforcement_learning
        The main difference is that instead of keeping all probs in one list,
        the agent keeps probs in a dictionary, one key per observer.

        NB: no need to enforce thread-safety here as the GIL will serialize
        executions.
        """
        probs = self.policy(state.unsqueeze(0))
        m = Categorical(probs)
        action = m.sample()
        self.saved_log_probs[ob_id].append(m.log_prob(action))
        return action.item()

    def report_reward(self, ob_id, reward):
        r"""
        Observers call this function to report rewards.
        """
        self.rewards[ob_id].append(reward)

    def run_episode(self, n_steps=0):
        r"""
        Run one episode. The agent tells each observer to run ``n_steps`` steps.
        """
        futs = []
        for ob_rref in self.ob_rrefs:
            # make an async RPC to kick off an episode on each observer
            futs.append(
                rpc_async(
                    ob_rref.owner(),
                    _call_method,
                    args=(Observer.run_episode, ob_rref, self.agent_rref, n_steps)
                )
            )

        # wait until all observers have finished this episode
        for fut in futs:
            fut.wait()

    def finish_episode(self):
        r"""
        This function is mostly borrowed from the Reinforcement Learning example.
        See https://github.com/pytorch/examples/tree/master/reinforcement_learning
        The main difference is that it joins all probs and rewards from
        different observers into one list, and uses the minimum observer reward
        as the reward of the current episode.
        """

        # join probs and rewards from different observers into lists
        R, probs, rewards = 0, [], []
        for ob_id in self.rewards:
            probs.extend(self.saved_log_probs[ob_id])
            rewards.extend(self.rewards[ob_id])

        # use the minimum observer reward to calculate the running reward
        min_reward = min(sum(self.rewards[ob_id]) for ob_id in self.rewards)
        self.running_reward = 0.05 * min_reward + (1 - 0.05) * self.running_reward

        # clear saved probs and rewards
        for ob_id in self.rewards:
            self.rewards[ob_id] = []
            self.saved_log_probs[ob_id] = []

        policy_loss, returns = [], []
        for r in rewards[::-1]:
            R = r + GAMMA * R
            returns.insert(0, R)
        returns = torch.tensor(returns)
        returns = (returns - returns.mean()) / (returns.std() + self.eps)
        for log_prob, R in zip(probs, returns):
            policy_loss.append(-log_prob * R)
        self.optimizer.zero_grad()
        policy_loss = torch.cat(policy_loss).sum()
        policy_loss.backward()
        self.optimizer.step()
        return min_reward


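# Worked example (assumed values) of the return computation in ``Agent.finish_episode``:
# with per-step rewards [1.0, 2.0, 3.0] and GAMMA = 0.1, the discounted returns are
# accumulated back to front:
#
#     R = 3.0
#     R = 2.0 + 0.1 * 3.0 = 2.3
#     R = 1.0 + 0.1 * 2.3 = 1.23
#
# giving returns [1.23, 2.3, 3.0], which are then normalized to zero mean and unit
# variance before weighting the saved log-probabilities in the policy-gradient loss.

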
def run_agent(agent, n_steps):
    for i_episode in count(1):
        agent.run_episode(n_steps=n_steps)
        last_reward = agent.finish_episode()

        if agent.running_reward > agent.reward_threshold:
            print(f"Solved! Running reward is now {agent.running_reward}!")
            break


class ReinforcementLearningRpcTest(RpcAgentTestFixture):
    @dist_init(setup_rpc=False)
    def test_rl_rpc(self):
        if self.rank == 0:
            # Rank 0 is the agent.
            rpc.init_rpc(
                name=worker_name(self.rank),
                backend=self.rpc_backend,
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=self.rpc_backend_options,
            )
            agent = Agent(self.world_size)
            run_agent(agent, n_steps=int(TOTAL_EPISODE_STEP / (self.world_size - 1)))

            # Ensure training was run. We don't really care about whether the task was
            # learned, since the purpose of this test is to check the API calls.
            self.assertGreater(agent.running_reward, 0.0)
        else:
            # Other ranks are observers that passively wait for instructions from the agent.
            rpc.init_rpc(
                name=worker_name(self.rank),
                backend=self.rpc_backend,
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=self.rpc_backend_options,
            )
        rpc.shutdown()
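

# The helper below is an illustrative sketch only (the function name is ours and it is
# not exercised by the test above): it shows how the same agent/observer setup could be
# launched outside of the test fixture, with one process per worker.
def _example_run_worker(rank, world_size):
    r"""
    A minimal sketch, assuming the default RPC backend and backend options, to be
    launched with one process per worker, e.g.::

        torch.multiprocessing.spawn(_example_run_worker, args=(2,), nprocs=2, join=True)
    """
    rpc.init_rpc(name=worker_name(rank), rank=rank, world_size=world_size)
    if rank == 0:
        # rank 0 acts as the agent and drives all observers
        agent = Agent(world_size)
        run_agent(agent, n_steps=TOTAL_EPISODE_STEP // (world_size - 1))
    # observers (rank > 0) passively serve RPCs until shutdown
    rpc.shutdown()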