import operator
import threading
import time
from functools import reduce

import torch
import torch.distributed.rpc as rpc
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical


OBSERVER_NAME = "observer{}"

class Policy(nn.Module):
    def __init__(self, in_features, nlayers, out_features):
        r"""
        Initializes the policy model
        Args:
            in_features (int): Number of input features the model takes
            nlayers (int): Number of layers in the model
            out_features (int): Number of features the model outputs
        """
        super().__init__()

        self.model = nn.Sequential(
            nn.Flatten(1, -1),
            nn.Linear(in_features, out_features),
            *[nn.Linear(out_features, out_features) for _ in range(nlayers)],
        )
        self.dim = 0

    def forward(self, x):
        action_scores = self.model(x)
        return F.softmax(action_scores, dim=self.dim)

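# Illustrative sketch (not exercised by the benchmark itself; the shapes below
# are assumptions chosen for the example): a policy over 4x4 states with two
# hidden layers and 2 candidate actions could be driven roughly as follows.
#
#   >>> policy = Policy(in_features=16, nlayers=2, out_features=2)
#   >>> probs = policy(torch.rand(8, 4, 4))  # states flattened to (8, 16); probs has shape (8, 2)
#   >>> actions = Categorical(probs).sample()  # one sampled action index per state
#
# Note that ``forward`` applies softmax over ``dim=0`` (the batch dimension) as
# written; ``Categorical`` renormalizes over the last dimension when sampling.
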
class AgentBase:
    def __init__(self):
        r"""
        Initializes the agent
        """
        self.id = rpc.get_worker_info().id
        self.running_reward = 0
        self.eps = 1e-7

        self.rewards = {}

        self.future_actions = torch.futures.Future()
        self.lock = threading.Lock()

        self.agent_latency_start = None
        self.agent_latency_end = None
        self.agent_latency = []
        self.agent_throughput = []

    def reset_metrics(self):
        r"""
        Resets all benchmark metrics to their empty values
        """
        self.agent_latency_start = None
        self.agent_latency_end = None
        self.agent_latency = []
        self.agent_throughput = []

    def set_world(self, batch_size, state_size, nlayers, out_features, batch=True):
        r"""
        Further initializes the agent to be aware of the RPC environment
        Args:
            batch_size (int): Size of batches of observer requests to process
            state_size (list): List of ints dictating the dimensions of the state
            nlayers (int): Number of layers in the model
            out_features (int): Number of output features in the model
            batch (bool): Whether to process and respond to observer requests as a batch or one at a time
        """
        self.batch = batch
        self.policy = Policy(reduce(operator.mul, state_size), nlayers, out_features)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-2)

        self.batch_size = batch_size
        for rank in range(batch_size):
            # Observer workers are named "observer2", "observer3", ... (offset by 2).
            ob_info = rpc.get_worker_info(OBSERVER_NAME.format(rank + 2))

            self.rewards[ob_info.id] = []

        self.saved_log_probs = (
            [] if self.batch else {k: [] for k in range(self.batch_size)}
        )

        self.pending_states = self.batch_size
        self.state_size = state_size
        self.states = torch.zeros(self.batch_size, *state_size)

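    # Driver-side sketch (an assumption about how the benchmark wires things up,
    # not code from this file): the coordinator typically holds an RRef to the
    # agent and configures it before observers start sending states, e.g.:
    #
    #   >>> agent_rref = rpc.remote("agent", AgentBase)
    #   >>> agent_rref.rpc_sync().set_world(
    #   ...     batch_size=4, state_size=[4, 4], nlayers=2, out_features=2, batch=True
    #   ... )
    #
    # The worker name "agent" and the sizes above are illustrative placeholders.
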
    @staticmethod
    @rpc.functions.async_execution
    def select_action_batch(agent_rref, observer_id, state):
        r"""
        Receives a state from an observer and selects an action for it.  Queues the observer's
        request for an action until the queue size equals the batch size specified during agent
        initialization, at which point actions are selected for all pending observer requests
        and communicated back to the observers
        Args:
            agent_rref (RRef): RRef of this agent
            observer_id (int): Observer id of observer calling this function
            state (Tensor): Tensor representing current state held by observer
        """
        self = agent_rref.local_value()
        # Observer worker ids start at 2 (see set_world); shift to a 0-based index.
        observer_id -= 2

        self.states[observer_id].copy_(state)
        future_action = self.future_actions.then(
            lambda future_actions: future_actions.wait()[observer_id].item()
        )

        with self.lock:
            if self.pending_states == self.batch_size:
                self.agent_latency_start = time.time()
            self.pending_states -= 1
            if self.pending_states == 0:
                self.pending_states = self.batch_size
                probs = self.policy(self.states)
                m = Categorical(probs)
                actions = m.sample()
                self.saved_log_probs.append(m.log_prob(actions).t())
                future_actions = self.future_actions
                self.future_actions = torch.futures.Future()
                future_actions.set_result(actions)

                self.agent_latency_end = time.time()

                batch_latency = self.agent_latency_end - self.agent_latency_start
                self.agent_latency.append(batch_latency)
                self.agent_throughput.append(self.batch_size / batch_latency)

        return future_action

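    # Observer-side sketch (assumed caller, not code from this file): each
    # observer forwards its current state to the agent's owning worker and
    # blocks on the returned action index, e.g.:
    #
    #   >>> action = rpc.rpc_sync(
    #   ...     agent_rref.owner(),
    #   ...     AgentBase.select_action_batch,
    #   ...     args=(agent_rref, rpc.get_worker_info().id, state),
    #   ... )
    #
    # Because of ``@rpc.functions.async_execution``, the RPC only completes once
    # a full batch of requests has been gathered and ``future_actions`` is set.
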
    @staticmethod
    def select_action_non_batch(agent_rref, observer_id, state):
        r"""
        Selects an action based on the observer's state and communicates it back to the observer
        Args:
            agent_rref (RRef): RRef of this agent
            observer_id (int): Observer id of observer calling this function
            state (Tensor): Tensor representing current state held by observer
        """
        self = agent_rref.local_value()
        observer_id -= 2
        agent_latency_start = time.time()

        state = state.float().unsqueeze(0)
        probs = self.policy(state)
        m = Categorical(probs)
        action = m.sample()
        self.saved_log_probs[observer_id].append(m.log_prob(action))

        agent_latency_end = time.time()
        non_batch_latency = agent_latency_end - agent_latency_start
        self.agent_latency.append(non_batch_latency)
        self.agent_throughput.append(1 / non_batch_latency)

        return action.item()

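    # Non-batched sketch (assumed caller, analogous to the batched example above):
    # observers call this variant when the agent was configured with ``batch=False``;
    # latency and throughput are then recorded per request rather than per batch, e.g.:
    #
    #   >>> action = rpc.rpc_sync(
    #   ...     agent_rref.owner(),
    #   ...     AgentBase.select_action_non_batch,
    #   ...     args=(agent_rref, rpc.get_worker_info().id, state),
    #   ... )
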
    def finish_episode(self, rets):
        r"""
        Finishes the episode and returns the recorded benchmark metrics
        Args:
            rets (list): List containing rewards generated by select action calls during
                the episode run
        """
        return self.agent_latency, self.agent_throughput
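
    # Driver-side sketch (assumed usage, mirroring the set_world example above):
    # after an episode, the coordinator collects the recorded metrics with
    #
    #   >>> latencies, throughputs = agent_rref.rpc_sync().finish_episode(rets)
    #
    # where ``rets`` would hold the per-observer rewards gathered during the run.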
172