import operator
import threading
import time
from functools import reduce

import torch
import torch.distributed.rpc as rpc
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical


OBSERVER_NAME = "observer{}"


class Policy(nn.Module):
    def __init__(self, in_features, nlayers, out_features):
        r"""
        Initializes the policy network
        Args:
            in_features (int): Number of input features the model takes
            nlayers (int): Number of layers in the model
            out_features (int): Number of features the model outputs
        """
        super().__init__()

        self.model = nn.Sequential(
            nn.Flatten(1, -1),
            nn.Linear(in_features, out_features),
            *[nn.Linear(out_features, out_features) for _ in range(nlayers)],
        )
        self.dim = 0

    def forward(self, x):
        action_scores = self.model(x)
        return F.softmax(action_scores, dim=self.dim)


class AgentBase:
    def __init__(self):
        r"""
        Initializes the agent
        """
        self.id = rpc.get_worker_info().id
        self.running_reward = 0
        self.eps = 1e-7

        self.rewards = {}

        self.future_actions = torch.futures.Future()
        self.lock = threading.Lock()

        self.agent_latency_start = None
        self.agent_latency_end = None
        self.agent_latency = []
        self.agent_throughput = []

    def reset_metrics(self):
        r"""
        Sets all benchmark metrics to their empty values
        """
        self.agent_latency_start = None
        self.agent_latency_end = None
        self.agent_latency = []
        self.agent_throughput = []

    def set_world(self, batch_size, state_size, nlayers, out_features, batch=True):
        r"""
        Further initializes the agent to be aware of the RPC environment
        Args:
            batch_size (int): Size of batches of observer requests to process
            state_size (list): List of ints dictating the dimensions of the state
            nlayers (int): Number of layers in the model
            out_features (int): Number of out features in the model
            batch (bool): Whether to process and respond to observer requests as a batch
                          or one at a time
        """
        self.batch = batch
        self.policy = Policy(reduce(operator.mul, state_size), nlayers, out_features)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-2)

        self.batch_size = batch_size
        for rank in range(batch_size):
            # Observer workers occupy ranks 2 through batch_size + 1
            ob_info = rpc.get_worker_info(OBSERVER_NAME.format(rank + 2))

            self.rewards[ob_info.id] = []

        self.saved_log_probs = (
            [] if self.batch else {k: [] for k in range(self.batch_size)}
        )

        self.pending_states = self.batch_size
        self.state_size = state_size
        self.states = torch.zeros(self.batch_size, *state_size)

    @staticmethod
    @rpc.functions.async_execution
    def select_action_batch(agent_rref, observer_id, state):
        r"""
        Receives a state from an observer and selects an action for it.  Queues the
        observer's request for an action until the number of queued requests equals
        the batch size specified during agent initialization, at which point actions
        are selected for all pending observer requests and communicated back to the
        observers
        Args:
            agent_rref (RRef): RRef of this agent
            observer_id (int): Observer id of observer calling this function
            state (Tensor): Tensor representing current state held by observer
        """
        self = agent_rref.local_value()
        # Observer worker ids start at 2; convert to a 0-based index
        observer_id -= 2

        self.states[observer_id].copy_(state)
        future_action = self.future_actions.then(
            lambda future_actions: future_actions.wait()[observer_id].item()
        )

        with self.lock:
            if self.pending_states == self.batch_size:
                self.agent_latency_start = time.time()
            self.pending_states -= 1
            if self.pending_states == 0:
                # All observer states have arrived: run the policy on the whole
                # batch and unblock every waiting observer at once
                self.pending_states = self.batch_size
                probs = self.policy(self.states)
                m = Categorical(probs)
                actions = m.sample()
                self.saved_log_probs.append(m.log_prob(actions).t())
                future_actions = self.future_actions
                self.future_actions = torch.futures.Future()
                future_actions.set_result(actions)

                self.agent_latency_end = time.time()

                batch_latency = self.agent_latency_end - self.agent_latency_start
                self.agent_latency.append(batch_latency)
                self.agent_throughput.append(self.batch_size / batch_latency)

        return future_action

    @staticmethod
    def select_action_non_batch(agent_rref, observer_id, state):
        r"""
        Selects an action based on the observer's state and communicates it back to
        the observer
        Args:
            agent_rref (RRef): RRef of this agent
            observer_id (int): Observer id of observer calling this function
            state (Tensor): Tensor representing current state held by observer
        """
        self = agent_rref.local_value()
        observer_id -= 2
        agent_latency_start = time.time()

        state = state.float().unsqueeze(0)
        probs = self.policy(state)
        m = Categorical(probs)
        action = m.sample()
        self.saved_log_probs[observer_id].append(m.log_prob(action))

        agent_latency_end = time.time()
        non_batch_latency = agent_latency_end - agent_latency_start
        self.agent_latency.append(non_batch_latency)
        self.agent_throughput.append(1 / non_batch_latency)

        return action.item()

    def finish_episode(self, rets):
        r"""
        Finishes the episode and returns the latency and throughput metrics recorded
        during it
        Args:
            rets (list): List containing rewards generated by select_action calls
                         during the episode run
        """
        return self.agent_latency, self.agent_throughput
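

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the benchmark itself): one minimal way to
# wire AgentBase to a set of observers over torch.distributed.rpc.  The worker
# layout (rank 0 = driver, rank 1 = agent, ranks 2+ = observers), the random
# placeholder states, and the helper names _observer_loop and _run_worker are
# illustrative assumptions; the real benchmark uses its own launcher and
# observer classes.
# ---------------------------------------------------------------------------
import os

import torch.multiprocessing as mp


def _observer_loop(agent_rref, observer_id, n_steps, state_size):
    # Runs on an observer worker: repeatedly sends a (random) state to the
    # agent and blocks until the batched action for this observer is returned.
    for _ in range(n_steps):
        state = torch.rand(state_size)
        # A real observer would step its environment with the returned action
        rpc.rpc_sync(
            agent_rref.owner(),
            AgentBase.select_action_batch,
            args=(agent_rref, observer_id, state),
        )


def _run_worker(rank, world_size, batch_size, state_size, n_steps):
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")

    if rank == 0:
        # Driver: creates the agent remotely and kicks off one loop per observer.
        rpc.init_rpc("driver", rank=rank, world_size=world_size)
        agent_rref = rpc.remote("agent", AgentBase)
        agent_rref.rpc_sync().set_world(
            batch_size, state_size, nlayers=2, out_features=4
        )
        futures = [
            rpc.rpc_async(
                OBSERVER_NAME.format(ob_rank),
                _observer_loop,
                args=(agent_rref, ob_rank, n_steps, state_size),
            )
            for ob_rank in range(2, 2 + batch_size)
        ]
        torch.futures.wait_all(futures)
        latency, throughput = agent_rref.rpc_sync().finish_episode([])
        print(f"batch latencies (s): {latency}")
        print(f"batch throughput (requests/s): {throughput}")
    elif rank == 1:
        # Agent worker: passively serves select_action_batch requests.
        rpc.init_rpc("agent", rank=rank, world_size=world_size)
    else:
        # Observer workers: passively run _observer_loop when the driver asks.
        rpc.init_rpc(OBSERVER_NAME.format(rank), rank=rank, world_size=world_size)
    rpc.shutdown()


if __name__ == "__main__":
    BATCH_SIZE = 3  # number of observers; one action batch per BATCH_SIZE requests
    STATE_SIZE = [4, 4]
    mp.spawn(
        _run_worker,
        args=(BATCH_SIZE + 2, BATCH_SIZE, STATE_SIZE, 5),
        nprocs=BATCH_SIZE + 2,
        join=True,
    )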