import random
import time

from agent import AgentBase

import torch
import torch.distributed.rpc as rpc
from torch.distributed.rpc import rpc_sync


class ObserverBase:
    def __init__(self):
        r"""
        Initializes the observer class
        """
        self.id = rpc.get_worker_info().id

    def set_state(self, state_size, batch):
        r"""
        Further initializes the observer to be aware of the RPC environment
        Args:
            state_size (list): List of integers denoting the dimensions of the state
            batch (bool): Whether the agent will be using batch select action
        """
        self.state_size = state_size
        self.select_action = (
            AgentBase.select_action_batch
            if batch
            else AgentBase.select_action_non_batch
        )

    def reset(self):
        r"""
        Resets the state randomly
        """
        state = torch.rand(self.state_size)
        return state

    def step(self, action):
        r"""
        Generates a random state and reward
        Args:
            action (int): Int received from the agent representing the action to take on the state
        """
        state = torch.rand(self.state_size)
        reward = random.randint(0, 1)

        return state, reward

    def run_ob_episode(self, agent_rref, n_steps):
        r"""
        Runs a single observer episode where, for n_steps, an action is selected
        from the agent based on the current state and the state is updated
        Args:
            agent_rref (RRef): Remote reference to the agent
            n_steps (int): Number of times to select an action to transform the state per episode
        """
        state, ep_reward = self.reset(), None
        rewards = torch.zeros(n_steps)
        observer_latencies = []
        observer_throughput = []

        for st in range(n_steps):
            ob_latency_start = time.time()
            action = rpc_sync(
                agent_rref.owner(),
                self.select_action,
                args=(agent_rref, self.id, state),
            )

            ob_latency = time.time() - ob_latency_start
            observer_latencies.append(ob_latency)
            observer_throughput.append(1 / ob_latency)

            state, reward = self.step(action)
            rewards[st] = reward

        return [rewards, ep_reward, observer_latencies, observer_throughput]
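

# --- Hypothetical usage sketch (illustrative only; never invoked by this module). ---
# This is a minimal sketch of how a driver process might construct an ObserverBase on
# its own worker and run one episode over RPC. It assumes an observer worker named
# "observer1" has joined the RPC group and that an RRef to the agent already exists;
# the worker name and the helper name below are assumptions for illustration, not part
# of this benchmark's launch code.
def _example_run_observer_episode(agent_rref, state_size, n_steps, batch=True):
    # Create the ObserverBase instance on the observer's worker and get an RRef to it.
    ob_rref = rpc.remote("observer1", ObserverBase)

    # Configure the remote observer synchronously via the RRef proxy.
    ob_rref.rpc_sync().set_state(state_size, batch)

    # Run a single episode on the observer's worker; this call blocks until it finishes.
    rewards, ep_reward, latencies, throughput = ob_rref.rpc_sync().run_ob_episode(
        agent_rref, n_steps
    )
    return rewards, latencies, throughput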