import argparse
import json
import os
import time

from coordinator import CoordinatorBase

import torch.distributed.rpc as rpc
import torch.multiprocessing as mp


COORDINATOR_NAME = "coordinator"
AGENT_NAME = "agent"
OBSERVER_NAME = "observer{}"

TOTAL_EPISODES = 10
TOTAL_EPISODE_STEPS = 100


def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")


parser = argparse.ArgumentParser(description="PyTorch RPC RL Benchmark")
parser.add_argument("--world-size", "--world_size", type=str, default="10")
parser.add_argument("--master-addr", "--master_addr", type=str, default="127.0.0.1")
parser.add_argument("--master-port", "--master_port", type=str, default="29501")
parser.add_argument("--batch", type=str, default="True")

parser.add_argument("--state-size", "--state_size", type=str, default="10-20-10")
parser.add_argument("--nlayers", type=str, default="5")
parser.add_argument("--out-features", "--out_features", type=str, default="10")
parser.add_argument(
    "--output-file-path",
    "--output_file_path",
    type=str,
    default="benchmark_report.json",
)

args = parser.parse_args()
args = vars(args)


def run_worker(
    rank,
    world_size,
    master_addr,
    master_port,
    batch,
    state_size,
    nlayers,
    out_features,
    queue,
):
    r"""
    Initializes an RPC worker.
    Args:
        rank (int): RPC rank of the worker process
        world_size (int): Number of workers in the RPC network (number of observers +
            1 agent + 1 coordinator)
        master_addr (str): Master address of the coordinator
        master_port (str): Master port of the coordinator
        batch (bool): Whether the agent batches observer requests or processes them
            one at a time
        state_size (str): Hyphen-separated string of state dimensions (e.g. 5-15-10)
        nlayers (int): Number of layers in the model
        out_features (int): Number of out features in the model
        queue (SimpleQueue): SimpleQueue from torch.multiprocessing.get_context() used
            to return benchmark run results
    """
    state_size = list(map(int, state_size.split("-")))
    batch_size = world_size - 2  # No. of observers

    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = master_port
    if rank == 0:
        rpc.init_rpc(COORDINATOR_NAME, rank=rank, world_size=world_size)

        coordinator = CoordinatorBase(
            batch_size, batch, state_size, nlayers, out_features
        )
        coordinator.run_coordinator(TOTAL_EPISODES, TOTAL_EPISODE_STEPS, queue)

    elif rank == 1:
        rpc.init_rpc(AGENT_NAME, rank=rank, world_size=world_size)
    else:
        rpc.init_rpc(OBSERVER_NAME.format(rank), rank=rank, world_size=world_size)
    rpc.shutdown()


def find_graph_variable(args):
    r"""
    Determines whether the user specified multiple entries for a single argument, in
    which case the benchmark is run once for each entry. Comma-separated values in an
    argument indicate multiple entries. Output is presented so that the user can use
    the plot repo to plot the results with each of that argument's entries on the
    x-axis. ``args`` is modified in place accordingly.
    More than one argument with multiple entries is not permitted.
    Args:
        args (dict): Dictionary containing arguments passed by the user (and defaults)
    """
    var_types = {
        "world_size": int,
        "state_size": str,
        "nlayers": int,
        "out_features": int,
        "batch": str2bool,
    }
    for arg in var_types.keys():
        if "," in args[arg]:
            if args.get("x_axis_name"):
                raise ValueError("Only 1 x axis graph variable allowed")
            args[arg] = list(
                map(var_types[arg], args[arg].split(","))
            )  # convert comma-separated str to list
            args["x_axis_name"] = arg
        else:
            args[arg] = var_types[arg](args[arg])  # convert string to proper type


def append_spaces(string, length):
    r"""
    Returns a modified string with spaces appended to the end. If the length of the
    string argument is greater than or equal to length, a single space is appended;
    otherwise x spaces are appended, where x is the difference between length and
    the length of string.
    Args:
        string (str): String to be modified
        length (int): Desired size of the returned string after spaces are appended
    Return: (str)
    """
    string = str(string)
    offset = length - len(string)
    if offset <= 0:
        offset = 1
    string += " " * offset
    return string


def print_benchmark_results(report):
    r"""
    Prints benchmark results.
    Args:
        report (dict): JSON-formatted dictionary containing relevant data on the run of this application
    """
    print("--------------------------------------------------------------")
    print("PyTorch distributed rpc benchmark reinforcement learning suite")
    print("--------------------------------------------------------------")
    for key, val in report.items():
        if key != "benchmark_results":
            print(f"{key} : {val}")

    x_axis_name = report.get("x_axis_name")
    col_width = 7
    heading = ""
    if x_axis_name:
        x_axis_output_label = f"{x_axis_name} |"
        heading += append_spaces(x_axis_output_label, col_width)
    metric_headers = [
        "agent latency (seconds)",
        "agent throughput",
        "observer latency (seconds)",
        "observer throughput",
    ]
    percentile_subheaders = ["p50", "p75", "p90", "p95"]
    subheading = ""
    if x_axis_name:
        subheading += append_spaces(" " * (len(x_axis_output_label) - 1), col_width)
    for header in metric_headers:
        heading += append_spaces(header, col_width * len(percentile_subheaders))
        for percentile in percentile_subheaders:
            subheading += append_spaces(percentile, col_width)
    print(heading)
    print(subheading)

    for benchmark_run in report["benchmark_results"]:
        run_results = ""
        if x_axis_name:
            run_results += append_spaces(
                benchmark_run[x_axis_name], max(col_width, len(x_axis_output_label))
            )
        for metric_name in metric_headers:
            percentile_results = benchmark_run[metric_name]
            for percentile in percentile_subheaders:
                run_results += append_spaces(percentile_results[percentile], col_width)
        print(run_results)


def main():
    r"""
    Runs the RPC benchmark once if no argument has multiple entries, otherwise once for
    each entry. Multiple entries are indicated by comma-separated values and are only
    allowed for a single argument. Results are printed as well as saved to the output
    file. When a single argument has multiple entries, the plot repo can be used to
    plot the benchmark results on the y-axis with each entry on the x-axis.
    """
    find_graph_variable(args)

    # run once if there is no x axis variable
    x_axis_variables = args[args["x_axis_name"]] if args.get("x_axis_name") else [None]
    ctx = mp.get_context("spawn")
    queue = ctx.SimpleQueue()
    benchmark_runs = []
    # run the benchmark once for every x axis variable
    for i, x_axis_variable in enumerate(x_axis_variables):
        if len(x_axis_variables) > 1:
            # set x axis variable for this benchmark iteration
            args[args["x_axis_name"]] = x_axis_variable
        processes = []
        start_time = time.time()
        for rank in range(args["world_size"]):
            prc = ctx.Process(
                target=run_worker,
                args=(
                    rank,
                    args["world_size"],
                    args["master_addr"],
                    args["master_port"],
                    args["batch"],
                    args["state_size"],
                    args["nlayers"],
                    args["out_features"],
                    queue,
                ),
            )
            prc.start()
            processes.append(prc)
        benchmark_run_results = queue.get()
        for process in processes:
            process.join()
        print(f"Time taken for benchmark run {i}: {time.time() - start_time}")
        if args.get("x_axis_name"):
            # save the x axis value for this iteration in the results
            benchmark_run_results[args["x_axis_name"]] = x_axis_variable
        benchmark_runs.append(benchmark_run_results)

    report = args
    report["benchmark_results"] = benchmark_runs
    if args.get("x_axis_name"):
        # x_axis_name was the swept variable, so don't save a constant for it in the report
        del report[args["x_axis_name"]]
    with open(args["output_file_path"], "w") as f:
        json.dump(report, f)
    print_benchmark_results(report)


if __name__ == "__main__":
    main()
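
# Example invocations (illustrative sketch; "launcher.py" stands in for whatever name
# this script is saved under -- the file name is not fixed by the code above):
#
#   # single benchmark run with the defaults (8 observers + 1 agent + 1 coordinator)
#   python launcher.py
#
#   # sweep world_size; comma-separated values make it the x-axis variable, so the
#   # report contains one "benchmark_results" entry per value
#   python launcher.py --world-size="10,15,20" --batch="True"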