# mypy: allow-untyped-defs

# If you need to modify this file to make this test pass, please also apply the same edits to
# https://github.com/pytorch/examples/blob/master/distributed/rpc/rl/main.py
# and https://pytorch.org/tutorials/intermediate/rpc_tutorial.html

from itertools import count

import numpy as np

import torch
import torch.distributed.rpc as rpc
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributed.rpc import RRef, rpc_sync, rpc_async, remote
from torch.distributions import Categorical

from torch.testing._internal.dist_utils import dist_init, worker_name
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import RpcAgentTestFixture

TOTAL_EPISODE_STEP = 5000  # total number of steps per episode, split evenly across observers
GAMMA = 0.1  # discount factor used when computing returns
SEED = 543  # seed for the dummy environments

def _call_method(method, rref, *args, **kwargs):
    r"""
    A helper function to call a method on the local value of the given RRef.
    """
    return method(rref.local_value(), *args, **kwargs)


def _remote_method(method, rref, *args, **kwargs):
    r"""
    A helper function to run a method on the owner of the given RRef and fetch
    the result back using RPC.
    """
    args = [method, rref] + list(args)
    return rpc_sync(rref.owner(), _call_method, args=args, kwargs=kwargs)


class Policy(nn.Module):
    r"""
    The ``Policy`` class is borrowed from the Reinforcement Learning example;
    the code is copied here to keep the two examples independent.
    See https://github.com/pytorch/examples/tree/master/reinforcement_learning
    """
    def __init__(self) -> None:
        super().__init__()
        self.affine1 = nn.Linear(4, 128)
        self.dropout = nn.Dropout(p=0.6)
        self.affine2 = nn.Linear(128, 2)

        self.saved_log_probs = []
        self.rewards = []

    def forward(self, x):
        x = self.affine1(x)
        x = self.dropout(x)
        x = F.relu(x)
        action_scores = self.affine2(x)
        return F.softmax(action_scores, dim=1)


class DummyEnv:
    r"""
    A dummy environment that implements the required subset of the OpenAI gym
    interface. It exists only to avoid a dependency on gym for running the
    tests in this file. It runs for a fixed maximum number of iterations,
    returning random states and rewards at each step.
    """
    def __init__(self, state_dim=4, num_iters=10, reward_threshold=475.0):
        self.state_dim = state_dim
        self.num_iters = num_iters
        self.iter = 0
        self.reward_threshold = reward_threshold

    def seed(self, manual_seed):
        torch.manual_seed(manual_seed)

    def reset(self):
        self.iter = 0
        return torch.randn(self.state_dim)

    def step(self, action):
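        # mimic gym's step() contract and return a (state, reward, done, info)
        # tuple with a random state and reward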
        self.iter += 1
        state = torch.randn(self.state_dim)
        reward = torch.rand(1).item() * self.reward_threshold
        done = self.iter >= self.num_iters
        info = {}
        return state, reward, done, info


class Observer:
    r"""
    An observer has exclusive access to its own environment. Each observer
    captures the state from its environment, and sends the state to the agent
    to select an action. Then, the observer applies the action to its
    environment and reports the reward to the agent.
    """
    def __init__(self) -> None:
        self.id = rpc.get_worker_info().id
        self.env = DummyEnv()
        self.env.seed(SEED)

    def run_episode(self, agent_rref, n_steps):
        r"""
        Run one episode of at most n_steps steps.
        Arguments:
            agent_rref (RRef): an RRef referencing the agent object.
            n_steps (int): maximum number of steps in this episode
        """
        state, ep_reward = self.env.reset(), 0
        for step in range(n_steps):
            # send the state to the agent to get an action
            action = _remote_method(Agent.select_action, agent_rref, self.id, state)

            # apply the action to the environment, and get the reward
            state, reward, done, _ = self.env.step(action)

            # report the reward to the agent for training purposes
            _remote_method(Agent.report_reward, agent_rref, self.id, reward)

            if done:
                break


class Agent:
    def __init__(self, world_size):
        self.ob_rrefs = []
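        # keep an RRef to self so that observers can make RPC calls back into
        # this agent (select_action and report_reward)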
        self.agent_rref = RRef(self)
        self.rewards = {}
        self.saved_log_probs = {}
        self.policy = Policy()
        self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-2)
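        # small epsilon added to the std when normalizing returns in
        # finish_episode() to avoid division by zero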
        self.eps = np.finfo(np.float32).eps.item()
        self.running_reward = 0
        self.reward_threshold = DummyEnv().reward_threshold
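        # create one remote Observer per non-agent rank; each observer object
        # lives on its own worker, and its rewards and log-probs are tracked in
        # dictionaries keyed by worker id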
        for ob_rank in range(1, world_size):
            ob_info = rpc.get_worker_info(worker_name(ob_rank))
            self.ob_rrefs.append(remote(ob_info, Observer))
            self.rewards[ob_info.id] = []
            self.saved_log_probs[ob_info.id] = []

    def select_action(self, ob_id, state):
        r"""
        This function is mostly borrowed from the Reinforcement Learning example.
        See https://github.com/pytorch/examples/tree/master/reinforcement_learning
        The main difference is that instead of keeping all probs in one list,
        the agent keeps probs in a dictionary, one key per observer.

        NB: no need to enforce thread-safety here as the GIL will serialize
        executions.
        """
        probs = self.policy(state.unsqueeze(0))
        m = Categorical(probs)
        action = m.sample()
        self.saved_log_probs[ob_id].append(m.log_prob(action))
        return action.item()

    def report_reward(self, ob_id, reward):
        r"""
        Observers call this function to report rewards.
        """
        self.rewards[ob_id].append(reward)

    def run_episode(self, n_steps=0):
        r"""
        Run one episode. The agent will tell each observer to run an episode
        of n_steps steps.
        """
        futs = []
        for ob_rref in self.ob_rrefs:
            # make an async RPC to kick off an episode on every observer
            futs.append(
                rpc_async(
                    ob_rref.owner(),
                    _call_method,
                    args=(Observer.run_episode, ob_rref, self.agent_rref, n_steps)
                )
            )

        # wait until all observers have finished this episode
        for fut in futs:
            fut.wait()

    def finish_episode(self):
        r"""
        This function is mostly borrowed from the Reinforcement Learning example.
        See https://github.com/pytorch/examples/tree/master/reinforcement_learning
        The main difference is that it joins all probs and rewards from
        different observers into one list, and uses the minimum observer reward
        as the reward of the current episode.
        """

        # join probs and rewards from different observers into flat lists
        R, probs, rewards = 0, [], []
        for ob_id in self.rewards:
            probs.extend(self.saved_log_probs[ob_id])
            rewards.extend(self.rewards[ob_id])

        # use the minimum observer reward to calculate the running reward
        min_reward = min(sum(self.rewards[ob_id]) for ob_id in self.rewards)
        self.running_reward = 0.05 * min_reward + (1 - 0.05) * self.running_reward

        # clear saved probs and rewards
        for ob_id in self.rewards:
            self.rewards[ob_id] = []
            self.saved_log_probs[ob_id] = []

        policy_loss, returns = [], []
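        # compute the discounted return for each step by iterating over the
        # rewards in reverse: R_t = r_t + GAMMA * R_{t+1}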
        for r in rewards[::-1]:
            R = r + GAMMA * R
            returns.insert(0, R)
        returns = torch.tensor(returns)
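        # normalize returns to zero mean and unit variance for more stable
        # gradients; self.eps guards against division by zero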
        returns = (returns - returns.mean()) / (returns.std() + self.eps)
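        # REINFORCE loss: the negative log-probability of each action weighted
        # by its (normalized) return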
        for log_prob, R in zip(probs, returns):
            policy_loss.append(-log_prob * R)
        self.optimizer.zero_grad()
        policy_loss = torch.cat(policy_loss).sum()
        policy_loss.backward()
        self.optimizer.step()
        return min_reward


def run_agent(agent, n_steps):
    for i_episode in count(1):
        agent.run_episode(n_steps=n_steps)
        last_reward = agent.finish_episode()

        if agent.running_reward > agent.reward_threshold:
            print(f"Solved! Running reward is now {agent.running_reward}!")
            break


class ReinforcementLearningRpcTest(RpcAgentTestFixture):
    @dist_init(setup_rpc=False)
    def test_rl_rpc(self):
        if self.rank == 0:
            # Rank 0 is the agent.
            rpc.init_rpc(
                name=worker_name(self.rank),
                backend=self.rpc_backend,
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=self.rpc_backend_options,
            )
            agent = Agent(self.world_size)
            run_agent(agent, n_steps=int(TOTAL_EPISODE_STEP / (self.world_size - 1)))

            # Ensure training was run. We don't really care about whether the
            # task was learned, since the purpose of the test is to check the
            # API calls.
            self.assertGreater(agent.running_reward, 0.0)
        else:
            # Other ranks are observers that passively wait for instructions
            # from the agent.
            rpc.init_rpc(
                name=worker_name(self.rank),
                backend=self.rpc_backend,
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=self.rpc_backend_options,
            )
        rpc.shutdown()