xref: /aosp_15_r20/external/pytorch/benchmarks/distributed/rpc/parameter_server/launcher.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import argparse
2import json
3import os
4from pathlib import Path
5
6from data import data_map
7from metrics.ProcessedMetricsPrinter import ProcessedMetricsPrinter
8from models import model_map
9from server import server_map
10from trainer import (
11    criterion_map,
12    ddp_hook_map,
13    ddp_model_map,
14    hook_state_map,
15    iteration_step_map,
16    preprocess_data_map,
17    trainer_map,
18)
19
20import torch
21import torch.distributed as c10d
22import torch.distributed.rpc as rpc
23import torch.multiprocessing as mp
24from torch.distributed.rpc import TensorPipeRpcBackendOptions
25from torch.futures import wait_all
26from torch.utils.data import DataLoader
27
28
def get_name(rank, args):
    r"""
    A function that builds the RPC worker name for a rank.
    Ranks are laid out as trainers first, then servers, and any
    remaining rank is named "master".
    Args:
        rank (int): process number in the world
        args (parser): benchmark configurations
    """
    trainer_total = args.ntrainer + args.ncudatrainer
    server_total = args.nserver + args.ncudaserver
    # Guard clauses, checked in rank order: trainers, then servers.
    if rank < trainer_total:
        return f"trainer{rank}"
    if rank < trainer_total + server_total:
        return f"server{rank}"
    return "master"
45
46
def get_server_rank(args, rank):
    r"""
    A function that maps a (CPU) trainer rank to the rank of the
    server it is paired with. Trainers are grouped into consecutive
    blocks of ``ntrainer // nserver`` ranks, and server ranks start
    right after all trainer ranks.
    Args:
        args (parser): benchmark configurations
        rank (int): trainer rank
    """
    first_server_rank = args.ntrainer + args.ncudatrainer
    trainers_per_server = args.ntrainer // args.nserver
    server_index = rank // trainers_per_server
    return server_index + first_server_rank
58
59
def get_cuda_server_rank(args, rank):
    r"""
    A function that maps a cudatrainer rank to the rank of the
    cudaserver it is paired with. Cudaserver ranks start after all
    trainers and all (CPU) servers; cudatrainers are grouped into
    consecutive blocks of ``ncudatrainer // ncudaserver`` ranks.
    Args:
        args (parser): benchmark configurations
        rank (int): trainer rank
    """
    first_cuda_server_rank = args.ntrainer + args.ncudatrainer + args.nserver
    # Position of this trainer among the cudatrainers only.
    cuda_trainer_index = rank - args.ntrainer
    cuda_trainers_per_server = args.ncudatrainer // args.ncudaserver
    server_index = cuda_trainer_index // cuda_trainers_per_server
    return server_index + first_cuda_server_rank
72
73
def get_server_rref(server_rank, args, extra_args):
    r"""
    A function that creates a RRef to the server by constructing the
    configured server class remotely on the worker that owns
    server_rank.
    Args:
        server_rank (int): process number in the world
        args (parser): benchmark configurations
        extra_args (dict): configurations added by the user
    """
    server_class = server_map[args.server]
    worker_name = get_name(server_rank, args)
    # User supplied values are forwarded positionally after the fixed arguments.
    user_args = [] if extra_args is None else extra_args.values()
    # Ranks at or beyond this boundary belong to cudaservers.
    first_cuda_server_rank = args.ntrainer + args.ncudatrainer + args.nserver
    use_cuda_rpc = server_rank >= first_cuda_server_rank
    if use_cuda_rpc:
        trainer_count = args.ncudatrainer / args.ncudaserver
    else:
        trainer_count = args.ntrainer / args.nserver
    constructor_args = (server_rank, trainer_count, use_cuda_rpc, *user_args)
    return rpc.remote(worker_name, server_class, args=constructor_args)
104
105
def run_trainer(args, extra_args, data, rank, server_rref):
    r"""
    A function that obtains a trainer instance, calls its train
    method and returns the metrics it collected.
    Args:
        args (parser): benchmark configurations
        extra_args (dict): configurations added by the user
        data (list): training samples
        rank (int): process number in the world
        server_rref (RRef): RRef to the server assigned to this trainer,
            or None when no servers are used
    Returns:
        list: ``[rank, metrics]`` where metrics is the value returned by
        the trainer's ``get_metrics`` method
    Raises:
        ValueError: if ``args.backend`` is not "gloo", "nccl" or "multi"
    """
    trainer_class = trainer_map[args.trainer]
    if extra_args is not None:
        trainer_args = extra_args.values()
    else:
        trainer_args = []
    trainer_count = args.ntrainer + args.ncudatrainer
    store = c10d.FileStore(args.filestore, trainer_count)
    if args.backend == "gloo":
        process_group = c10d.ProcessGroupGloo(store, rank, trainer_count)
    elif args.backend == "nccl":
        process_group = c10d.ProcessGroupNCCL(store, rank, trainer_count)
    elif args.backend == "multi":
        process_group = c10d.ProcessGroupNCCL(store, rank, trainer_count)
        # "multi" additionally brings up the default process group on gloo.
        if not c10d.is_initialized():
            c10d.init_process_group(backend="gloo", rank=rank, world_size=trainer_count)
    else:
        # Without this branch process_group would be unbound and the failure
        # would surface later as an opaque UnboundLocalError.
        raise ValueError(
            f"unsupported backend {args.backend!r}; "
            "expected one of 'gloo', 'nccl' or 'multi'"
        )

    model = load_model(args)
    preprocess_data = preprocess_data_map[args.preprocess_data]
    create_criterion = criterion_map[args.create_criterion]
    create_ddp_model = ddp_model_map[args.create_ddp_model]
    iteration_step = iteration_step_map[args.iteration_step]
    hook_state_class = hook_state_map[args.hook_state]
    hook = ddp_hook_map[args.ddp_hook]
    # cudatrainers occupy the ranks that follow the CPU trainers
    use_cuda_rpc = rank >= args.ntrainer
    trainer = trainer_class(
        process_group,
        use_cuda_rpc,
        server_rref,
        args.backend,
        args.epochs,
        preprocess_data,
        create_criterion,
        create_ddp_model,
        hook_state_class,
        hook,
        iteration_step,
        *trainer_args,
    )
    trainer.train(model, data)
    metrics = trainer.get_metrics()
    return [rank, metrics]
159
160
def call_trainers(args, extra_args, train_data, server_rrefs):
    r"""
    A function that starts the trainers. One rpc_async request running
    run_trainer is issued per trainer rank and the resulting futures
    are returned in rank order.
    Args:
        args (parser): benchmark configurations
        extra_args (dict): configurations added by the user
        train_data (list): training samples
        server_rrefs (dict): a dictionary containing server RRefs
    """
    pending = []
    trainer_total = args.ntrainer + args.ncudatrainer
    for trainer_rank in range(trainer_total):
        worker_name = get_name(trainer_rank, args)
        if not server_rrefs:
            # No servers were created, so the trainer runs without one.
            assigned_rref = None
        else:
            # cudatrainers follow the CPU trainers in rank order.
            if trainer_rank < args.ntrainer:
                assigned_rank = get_server_rank(args, trainer_rank)
            else:
                assigned_rank = get_cuda_server_rank(args, trainer_rank)
            assigned_rref = server_rrefs[assigned_rank]
        call_args = (
            args,
            extra_args,
            train_data[trainer_rank],
            trainer_rank,
            assigned_rref,
        )
        pending.append(
            rpc.rpc_async(
                worker_name,
                run_trainer,
                args=call_args,
                timeout=args.rpc_timeout,
            )
        )
    return pending
195
196
def benchmark_warmup(args, extra_args, data, server_rrefs):
    r"""
    A function that runs the training algorithm once to warm the rpc.
    After every trainer has finished, reset_state is invoked on each
    server so the warmup run does not carry over.
    Args:
        args (parser): benchmark configurations
        extra_args (dict): configurations added by the user
        data (list): training samples
        server_rrefs (dict): a dictionary containing server RRefs
    """
    warmup_futs = call_trainers(args, extra_args, data, server_rrefs)
    wait_all(warmup_futs)
    for rref in server_rrefs.values():
        remote_server = rref.rpc_sync()
        remote_server.reset_state(rref)
    print("benchmark warmup done\n")
212
213
def split_list(arr, n):
    r"""
    A function that splits a list into n lists. Elements are dealt
    round-robin: output list i holds the elements at positions
    i, i + n, i + 2n, ...
    Args:
        arr (list): training samples
        n (int): number of output lists
    """
    shards = []
    for offset in range(n):
        shards.append(arr[offset::n])
    return shards
222
223
def get_server_metrics(server_rrefs):
    r"""
    A function that calls each remote server to obtain the metrics
    collected during the benchmark run. The result is a list of
    [rank, metrics] pairs in the iteration order of server_rrefs.
    Args:
        server_rrefs (dict): a dictionary containing server RRefs
    """
    return [
        [server_rank, rref.rpc_sync().get_metrics(rref)]
        for server_rank, rref in server_rrefs.items()
    ]
236
237
def run_master(rank, data, args, extra_configs, rpc_backend_options):
    r"""
    A function that runs the master process in the world. This function
    joins the rpc world, obtains remote references to initialized servers,
    splits the data, runs a warmup pass followed by the measured pass,
    and prints trainer and server metrics.
    Args:
        rank (int): process number in the world
        data (list): training samples
        args (parser): benchmark configurations
        extra_configs (dict): configurations added by the user
        rpc_backend_options (rpc): configurations/options for the rpc
    """
    trainer_total = args.ntrainer + args.ncudatrainer
    world_size = trainer_total + args.nserver + args.ncudaserver + 1
    rpc.init_rpc(
        get_name(rank, args),
        rank=rank,
        world_size=world_size,
        rpc_backend_options=rpc_backend_options,
    )

    # Server ranks sit between the trainers and the master (the last rank).
    server_rrefs = {
        server_rank: get_server_rref(
            server_rank, args, extra_configs["server_config"]
        )
        for server_rank in range(trainer_total, world_size - 1)
    }

    batches = list(DataLoader(data, batch_size=args.batch_size))
    train_data = split_list(batches, trainer_total)

    # warmup pass
    benchmark_warmup(args, extra_configs["trainer_config"], train_data, server_rrefs)

    # measured pass
    trainer_futs = call_trainers(
        args, extra_configs["trainer_config"], train_data, server_rrefs
    )

    # gather the results and report them
    metrics_printer = ProcessedMetricsPrinter()
    trainer_metrics = wait_all(trainer_futs)
    metrics_printer.print_metrics("trainer", trainer_metrics)
    server_metrics = get_server_metrics(server_rrefs)
    metrics_printer.print_metrics("server", server_metrics)
277
278
def run_benchmark(rank, args, data):
    r"""
    A function that runs the benchmark for a single process of the world.
    It loads the extra configurations, seeds torch, exports the rendezvous
    address/port through the environment and then joins the rpc world in
    the role implied by rank (trainer, server or master). Every role ends
    by calling rpc.shutdown().
    Args:
        rank (int): process number in the world
        args (parser): configuration args
        data (list): training samples
    """

    config = load_extra_configs(args)

    # Seed the CPU and CUDA random number generators from the launcher args.
    torch.manual_seed(args.torch_seed)
    torch.cuda.manual_seed_all(args.cuda_seed)
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True

    # trainers + servers + one master process
    world_size = args.ntrainer + args.ncudatrainer + args.nserver + args.ncudaserver + 1
    # MASTER_ADDR / MASTER_PORT are set before any init_rpc call below.
    os.environ["MASTER_ADDR"] = args.master_addr
    os.environ["MASTER_PORT"] = args.master_port
    rpc_backend_options = TensorPipeRpcBackendOptions(rpc_timeout=args.rpc_timeout)
    if rank == world_size - 1:
        # master: the last rank, i.e. ntrainer + ncudatrainer + nserver + ncudaserver
        run_master(rank, data, args, config, rpc_backend_options)
    elif rank >= args.ntrainer + args.ncudatrainer:
        # parameter servers: ranks in
        # [ntrainer + ncudatrainer, ntrainer + ncudatrainer + nserver + ncudaserver)
        rpc.init_rpc(
            get_name(rank, args),
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )
    else:
        # trainers: ranks in [0, ntrainer + ncudatrainer)
        if rank >= args.ntrainer:
            # cudatrainer: register a device map towards its cudaserver.
            # NOTE(review): the map is {rank: server_rank}, i.e. the process
            # ranks are used as device indices - confirm this is intended.
            server_rank = get_cuda_server_rank(args, rank)
            server_name = get_name(server_rank, args)
            rpc_backend_options.set_device_map(server_name, {rank: server_rank})
        trainer_name = get_name(rank, args)
        rpc.init_rpc(
            trainer_name,
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )
    # Reached by every role once its work above has returned.
    rpc.shutdown()
324
325
def get_json_config(file_name: str, id: str):
    r"""
    A function that loads a json configuration from a file. The file
    is resolved relative to the directory containing this module and
    the entry stored under id is returned.
    Args:
        file_name (str): name of configuration file to load
        id (str): configuration that will be loaded
    """
    config_path = Path(__file__).parent / file_name
    with config_path.open() as handle:
        all_configs = json.load(handle)
        return all_configs[id]
336
337
def load_extra_configs(args):
    r"""
    A function that creates a dictionary that contains any extra configurations
    set by the user. The dictionary has the two keys trainer_config and
    server_config; each stays None unless both the matching map key
    (args.trainer / args.server) and the matching config path are set.
    Args:
        args (parser): launcher configurations
    """
    trainer_path = args.trainer_config_path
    server_path = args.server_config_path

    trainer_config = None
    if not (args.trainer is None or trainer_path is None):
        trainer_config = get_json_config(trainer_path, args.trainer)

    server_config = None
    if not (args.server is None or server_path is None):
        server_config = get_json_config(server_path, args.server)

    return {"trainer_config": trainer_config, "server_config": server_config}
358
359
def load_data(args):
    r"""
    A function that creates an instance of the data class. The entry
    named by args.data selects both the class (through data_map) and
    the keyword arguments used to construct it.
    Args:
        args (parser): launcher configurations
    """
    spec = get_json_config(args.data_config_path, args.data)
    data_factory = data_map[spec["data_class"]]
    data_kwargs = spec["configurations"]
    return data_factory(**data_kwargs)
370
371
def load_model(args):
    r"""
    A function that creates an instance of the model class. The entry
    named by args.model selects both the class (through model_map) and
    the keyword arguments used to construct it.
    Args:
        args (parser): launcher configurations
    """
    spec = get_json_config(args.model_config_path, args.model)
    model_factory = model_map[spec["model_class"]]
    model_kwargs = spec["configurations"]
    return model_factory(**model_kwargs)
382
383
def main(args):
    r"""
    A function that validates the requested topology and creates
    multiple processes to run the benchmark. One process is spawned per
    trainer, cudatrainer, server and cudaserver, plus one master.
    Args:
        args (parser): launcher configurations
    Raises:
        ValueError: if the trainer/server counts are inconsistent
    """
    # Explicit exceptions are used instead of assert so the checks still
    # run when Python is started with -O.
    # CPU and RPC trainer checks
    if args.ntrainer > 0 and args.ncudatrainer > 0:
        if not (args.nserver > 0 and args.ncudaserver > 0):
            raise ValueError(
                "mixing trainers and cudatrainers requires nserver > 0 "
                "and ncudaserver > 0"
            )
    if args.nserver > 0:
        if not args.ntrainer > 0:
            raise ValueError("nserver > 0 requires ntrainer > 0")
        if args.ntrainer % args.nserver != 0:
            raise ValueError("ntrainer must be a multiple of nserver")
    if args.ncudaserver > 0:
        if not args.ncudatrainer > 0:
            raise ValueError("ncudaserver > 0 requires ncudatrainer > 0")
        if args.ncudatrainer % args.ncudaserver != 0:
            raise ValueError("ncudatrainer must be a multiple of ncudaserver")

    # trainers + servers + one master process
    world_size = args.ntrainer + args.ncudatrainer + args.nserver + args.ncudaserver + 1

    data = load_data(args)

    mp.spawn(
        run_benchmark,
        args=(
            args,
            data,
        ),
        nprocs=world_size,
        join=True,
    )
413
414
if __name__ == "__main__":
    # Command line entry point: parse the launcher configuration, echo it,
    # and hand it to main(). Every option accepts both the dashed and the
    # underscored spelling where the original launcher offered both.
    parser = argparse.ArgumentParser(description="RPC server Benchmark")

    # --- rendezvous ---
    parser.add_argument(
        "--master-addr",
        "--master_addr",
        type=str,
        help="IP address of the machine that will host the process with rank 0",
    )
    parser.add_argument(
        "--master-port",
        "--master_port",
        type=str,
        help="A free port on the machine that will host the process with rank 0",
    )

    # --- trainers ---
    parser.add_argument(
        "--trainer",
        type=str,
        help="trainer map key to get trainer class for benchmark run",
    )
    parser.add_argument("--ntrainer", type=int, help="trainer count for benchmark run")
    parser.add_argument(
        "--ncudatrainer", type=int, help="cudatrainer count for benchmark run"
    )
    parser.add_argument(
        "--filestore", type=str, help="filestore location for process group"
    )

    # --- servers ---
    parser.add_argument(
        "--server",
        type=str,
        # The original help text said "trainer class" (copy-paste slip).
        help="server map key to get server class for benchmark run",
    )
    parser.add_argument("--nserver", type=int, help="server count for benchmark run")
    parser.add_argument(
        "--ncudaserver", type=int, help="cudaserver count for benchmark run"
    )

    # --- communication ---
    parser.add_argument(
        "--rpc-timeout",
        "--rpc_timeout",
        type=int,
        help="timeout in seconds to use for RPC",
    )
    parser.add_argument(
        "--backend",
        type=str,
        help="distributed communication backend to use for benchmark run",
    )

    # --- training run ---
    parser.add_argument("--epochs", type=int, help="epoch count for training")
    parser.add_argument(
        "--batch-size",
        "--batch_size",
        type=int,
        help="number of training examples used in one iteration",
    )

    # --- configuration ids and files ---
    parser.add_argument("--data", type=str, help="id for data configuration")
    parser.add_argument("--model", type=str, help="id for model configuration")
    parser.add_argument(
        "--data-config-path",
        "--data_config_path",
        type=str,
        help="path to data configuration file",
    )
    parser.add_argument(
        "--model-config-path",
        "--model_config_path",
        type=str,
        help="path to model configuration file",
    )
    parser.add_argument(
        "--server-config-path",
        "--server_config_path",
        type=str,
        help="path to server configuration file",
    )
    parser.add_argument(
        "--trainer-config-path",
        "--trainer_config_path",
        type=str,
        help="path to trainer configuration file",
    )

    # --- seeding ---
    parser.add_argument(
        "--torch-seed",
        "--torch_seed",
        type=int,
        help="seed for generating random numbers to a non-deterministic random number",
    )
    parser.add_argument(
        "--cuda-seed",
        "--cuda_seed",
        type=int,
        help="seed for generating random numbers to a random number for the current GPU",
    )

    # --- pluggable training functions (keys into the trainer maps) ---
    parser.add_argument(
        "--preprocess-data",
        "--preprocess_data",
        type=str,
        help="this function will be used to preprocess data before training",
    )
    parser.add_argument(
        "--create-criterion",
        "--create_criterion",
        type=str,
        help="this function will be used to create the criterion used for model loss calculation",
    )
    parser.add_argument(
        "--create-ddp-model",
        "--create_ddp_model",
        type=str,
        help="this function will be used to create the ddp model used during training",
    )
    parser.add_argument(
        "--hook-state",
        "--hook_state",
        type=str,
        help="this will be the state class used when registering the ddp communication hook",
    )
    parser.add_argument(
        "--ddp-hook",
        "--ddp_hook",
        type=str,
        default="allreduce_hook",
        help="ddp communication hook",
    )
    parser.add_argument(
        "--iteration-step",
        "--iteration_step",
        type=str,
        help="this will be the function called for each iteration of training",
    )

    args = parser.parse_args()
    print(f"{args}\n")
    main(args)
546