# external/pytorch/benchmarks/distributed/rpc/rl/observer.py
import random
import time

from agent import AgentBase

import torch
import torch.distributed.rpc as rpc
from torch.distributed.rpc import rpc_sync


class ObserverBase:
    def __init__(self):
        r"""
        Initializes the observer and records the id of its RPC worker.
        """
        self.id = rpc.get_worker_info().id

    def set_state(self, state_size, batch):
        r"""
        Finishes initializing the observer once the RPC environment is set up.
        Args:
            state_size (list): List of integers denoting the dimensions of the state
            batch (bool): Whether the agent uses batched action selection
        """
        self.state_size = state_size
        self.select_action = (
            AgentBase.select_action_batch
            if batch
            else AgentBase.select_action_non_batch
        )
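        # The unbound AgentBase method stored here is never called locally; it
        # is passed to rpc_sync in run_ob_episode and executed on the agent's
        # worker.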

    def reset(self):
        r"""
        Returns a freshly generated random state.
        """
        state = torch.rand(self.state_size)
        return state

    def step(self, action):
        r"""
        Generates a random next state and reward. The action is accepted for
        API compatibility but does not influence the result, since this
        benchmark environment is purely random.
        Args:
            action (int): Action received from the agent for the current state
        """
        state = torch.rand(self.state_size)
        reward = random.randint(0, 1)

        return state, reward

    def run_ob_episode(self, agent_rref, n_steps):
        r"""
        Runs a single observer episode: for n_steps iterations, an action is
        requested from the agent based on the current state, and the state is
        then updated.
        Args:
            agent_rref (RRef): Remote reference to the agent
            n_steps (int): Number of times per episode to select an action and transform the state
        """
        # ep_reward is initialized to None and never updated here; it is
        # returned as a placeholder.
        state, ep_reward = self.reset(), None
        rewards = torch.zeros(n_steps)
        observer_latencies = []
        observer_throughput = []

        for st in range(n_steps):
            # The synchronous RPC blocks until the agent returns an action, so
            # the elapsed time measures the per-step round-trip latency.
            ob_latency_start = time.time()
            action = rpc_sync(
                agent_rref.owner(),
                self.select_action,
                args=(agent_rref, self.id, state),
            )

            ob_latency = time.time() - ob_latency_start
            observer_latencies.append(ob_latency)
            observer_throughput.append(1 / ob_latency)

            state, reward = self.step(action)
            rewards[st] = reward

        return [rewards, ep_reward, observer_latencies, observer_throughput]
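

# A minimal, hypothetical sketch (not part of the original benchmark) of how an
# agent-side process with an already-initialized RPC context might create a
# remote observer and run one episode. The worker name "observer1", the state
# size, and this helper itself are illustrative assumptions; the real benchmark
# launcher wires up workers and episodes on its own.
def _example_remote_episode(agent_rref, observer_name="observer1", n_steps=10):
    # Instantiate an ObserverBase on the remote worker and obtain an RRef to it.
    ob_rref = rpc.remote(observer_name, ObserverBase)
    # Configure the observer, then run a single episode against the agent.
    ob_rref.rpc_sync().set_state(state_size=[1, 4], batch=False)
    return ob_rref.rpc_sync().run_ob_episode(agent_rref, n_steps)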