import argparse
import json
import os
import time

from coordinator import CoordinatorBase

import torch.distributed.rpc as rpc
import torch.multiprocessing as mp


COORDINATOR_NAME = "coordinator"
AGENT_NAME = "agent"
OBSERVER_NAME = "observer{}"

TOTAL_EPISODES = 10
TOTAL_EPISODE_STEPS = 100


def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")
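
# str2bool is not registered as an argparse type; --batch is parsed below as a plain
# string and converted by find_graph_variable, which also accepts comma-separated
# values (e.g. "True,False") when sweeping a variable.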


parser = argparse.ArgumentParser(description="PyTorch RPC RL Benchmark")
parser.add_argument("--world-size", "--world_size", type=str, default="10")
parser.add_argument("--master-addr", "--master_addr", type=str, default="127.0.0.1")
parser.add_argument("--master-port", "--master_port", type=str, default="29501")
parser.add_argument("--batch", type=str, default="True")

parser.add_argument("--state-size", "--state_size", type=str, default="10-20-10")
parser.add_argument("--nlayers", type=str, default="5")
parser.add_argument("--out-features", "--out_features", type=str, default="10")
parser.add_argument(
    "--output-file-path",
    "--output_file_path",
    type=str,
    default="benchmark_report.json",
)

args = parser.parse_args()
args = vars(args)
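
# Example invocation (illustrative values):
#   python launcher.py --world-size 12 --batch True --state-size 10-20-10 \
#       --nlayers 5 --out-features 10 --output-file-path benchmark_report.json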


def run_worker(
    rank,
    world_size,
    master_addr,
    master_port,
    batch,
    state_size,
    nlayers,
    out_features,
    queue,
):
    r"""
    Initializes an RPC worker.
    Args:
        rank (int): RPC rank of the worker machine
        world_size (int): Number of workers in the RPC network (number of observers +
                          1 agent + 1 coordinator)
        master_addr (str): Master address of the coordinator
        master_port (str): Master port of the coordinator
        batch (bool): Whether the agent batches observer requests or processes them
                      one at a time
        state_size (str): Dash-separated string of state dimensions (e.g. 5-15-10)
        nlayers (int): Number of layers in the model
        out_features (int): Number of output features in the model
        queue (SimpleQueue): SimpleQueue from torch.multiprocessing.get_context() used
                             to return benchmark run results
    """
    state_size = list(map(int, state_size.split("-")))
    batch_size = world_size - 2  # No. of observers

    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = master_port
    if rank == 0:
        rpc.init_rpc(COORDINATOR_NAME, rank=rank, world_size=world_size)

        coordinator = CoordinatorBase(
            batch_size, batch, state_size, nlayers, out_features
        )
        coordinator.run_coordinator(TOTAL_EPISODES, TOTAL_EPISODE_STEPS, queue)

    elif rank == 1:
        rpc.init_rpc(AGENT_NAME, rank=rank, world_size=world_size)
    else:
        rpc.init_rpc(OBSERVER_NAME.format(rank), rank=rank, world_size=world_size)
    rpc.shutdown()


def find_graph_variable(args):
    r"""
    Determines whether the user specified multiple entries for a single argument, in which
    case the benchmark is run once for each entry. Comma-separated values in an argument
    indicate multiple entries. Output is presented so that the user can use the plot repo
    to plot the results with each of the variable argument's entries on the x axis. args is
    modified in place accordingly. Multiple entries are permitted for at most one argument.
    Args:
        args (dict): Dictionary containing arguments passed by the user (and default arguments)
    """
    var_types = {
        "world_size": int,
        "state_size": str,
        "nlayers": int,
        "out_features": int,
        "batch": str2bool,
    }
    for arg in var_types.keys():
        if "," in args[arg]:
            if args.get("x_axis_name"):
                raise ValueError("Only 1 x axis graph variable allowed")
            args[arg] = list(
                map(var_types[arg], args[arg].split(","))
            )  # convert comma-separated str to list
            args["x_axis_name"] = arg
        else:
            args[arg] = var_types[arg](args[arg])  # convert string to proper type
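
# Example (illustrative values): passing --world-size "10,15,20" makes find_graph_variable
# set args["world_size"] = [10, 15, 20] and args["x_axis_name"] = "world_size", so main()
# below runs the benchmark once per world size.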


def append_spaces(string, length):
    r"""
    Returns the string with spaces appended to the end. If the length of the string is
    greater than or equal to length, a single space is appended; otherwise enough spaces
    are appended to pad the string to length.
    Args:
        string (str): String to be modified
        length (int): Desired size of the returned string, including appended spaces
    Return: (str)
    """
    string = str(string)
    offset = length - len(string)
    if offset <= 0:
        offset = 1
    string += " " * offset
    return string
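
# For example, append_spaces("p50", 7) returns "p50    " (padded with four spaces),
# which is used below to align table cells to a fixed column width.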


def print_benchmark_results(report):
    r"""
    Prints benchmark results.
    Args:
        report (dict): JSON-serializable dictionary containing the arguments and results
                       of this benchmark run
    """
    print("--------------------------------------------------------------")
    print("PyTorch distributed rpc benchmark reinforcement learning suite")
    print("--------------------------------------------------------------")
    for key, val in report.items():
        if key != "benchmark_results":
            print(f"{key} : {val}")

    x_axis_name = report.get("x_axis_name")
    col_width = 7
    heading = ""
    if x_axis_name:
        x_axis_output_label = f"{x_axis_name} |"
        heading += append_spaces(x_axis_output_label, col_width)
    metric_headers = [
        "agent latency (seconds)",
        "agent throughput",
        "observer latency (seconds)",
        "observer throughput",
    ]
    percentile_subheaders = ["p50", "p75", "p90", "p95"]
    subheading = ""
    if x_axis_name:
        subheading += append_spaces(" " * (len(x_axis_output_label) - 1), col_width)
    for header in metric_headers:
        heading += append_spaces(header, col_width * len(percentile_subheaders))
        for percentile in percentile_subheaders:
            subheading += append_spaces(percentile, col_width)
    print(heading)
    print(subheading)

    for benchmark_run in report["benchmark_results"]:
        run_results = ""
        if x_axis_name:
            run_results += append_spaces(
                benchmark_run[x_axis_name], max(col_width, len(x_axis_output_label))
            )
        for metric_name in metric_headers:
            percentile_results = benchmark_run[metric_name]
            for percentile in percentile_subheaders:
                run_results += append_spaces(percentile_results[percentile], col_width)
        print(run_results)
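
# Shape assumed above (illustrative, inferred from how the results are indexed): each entry
# of report["benchmark_results"] looks like
#   {"agent latency (seconds)": {"p50": ..., "p75": ..., "p90": ..., "p95": ...}, ...}
# with one inner dict per metric header.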


def main():
    r"""
    Runs the RPC benchmark once if no argument has multiple entries, and otherwise once per
    entry. Multiple entries are indicated by comma-separated values and are allowed for only
    one argument. Results are printed as well as saved to the output file. When an argument
    has multiple entries, the plot repo can be used to plot the results on the y axis with
    each entry on the x axis.
    """
    find_graph_variable(args)

    # run once if no x axis variables
    x_axis_variables = args[args["x_axis_name"]] if args.get("x_axis_name") else [None]
    ctx = mp.get_context("spawn")
    queue = ctx.SimpleQueue()
    benchmark_runs = []
    for i, x_axis_variable in enumerate(
        x_axis_variables
    ):  # run benchmark for every x axis variable
        if len(x_axis_variables) > 1:
            args[
                args["x_axis_name"]
            ] = x_axis_variable  # set x axis variable for this benchmark iteration
        processes = []
        start_time = time.time()
        for rank in range(args["world_size"]):
            prc = ctx.Process(
                target=run_worker,
                args=(
                    rank,
                    args["world_size"],
                    args["master_addr"],
                    args["master_port"],
                    args["batch"],
                    args["state_size"],
                    args["nlayers"],
                    args["out_features"],
                    queue,
                ),
            )
            prc.start()
            processes.append(prc)
        benchmark_run_results = queue.get()
        for process in processes:
            process.join()
        print(f"Time taken for benchmark run {i}: {time.time() - start_time}")
        if args.get("x_axis_name"):
            # save the x axis value used for this iteration in the results
            benchmark_run_results[args["x_axis_name"]] = x_axis_variable
        benchmark_runs.append(benchmark_run_results)

    report = args
    report["benchmark_results"] = benchmark_runs
    if args.get("x_axis_name"):
        # the x axis argument varied across runs, so don't save a constant value for it in the report
        del report[args["x_axis_name"]]
    with open(args["output_file_path"], "w") as f:
        json.dump(report, f)
    print_benchmark_results(report)


if __name__ == "__main__":
    main()