#!/usr/bin/env python3
# Owner(s): ["oncall: r2p"]

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os

import torch
import torch.distributed as dist
import torch.nn.functional as F


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options used to set up the process group.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what the
            script entry point relies on.

    Returns:
        Namespace with ``init_method`` (str), ``world_size`` (int) and
        ``rank`` (int).
    """
    parser = argparse.ArgumentParser(description="test script")

    parser.add_argument(
        "--init-method",
        "--init_method",
        type=str,
        required=True,
        help="init_method to pass to `dist.init_process_group()` (e.g. env://)",
    )
    # The env-var defaults are strings when set; argparse runs string defaults
    # through ``type`` so they end up as ints. The -1 fallback is used as-is.
    parser.add_argument(
        "--world-size",
        "--world_size",
        type=int,
        default=os.getenv("WORLD_SIZE", -1),
        help="world_size to pass to `dist.init_process_group()`",
    )
    parser.add_argument(
        "--rank",
        type=int,
        default=os.getenv("RANK", -1),
        help="rank to pass to `dist.init_process_group()`",
    )

    return parser.parse_args(argv)


def main() -> None:
    """Initialise a gloo process group and sanity-check it with an all_reduce.

    Every rank contributes a one-hot vector indexed by its rank; the reduced
    vector is summed and compared with the reported world size.

    Raises:
        RuntimeError: if the sum of the reduced tensor differs from the world
            size reported by ``dist.get_world_size()``.
    """
    args = parse_args()

    dist.init_process_group(
        backend="gloo",
        init_method=args.init_method,
        world_size=args.world_size,
        rank=args.rank,
    )

    # Only reached when initialisation succeeded, so the teardown in
    # ``finally`` always has a process group to destroy.
    try:
        rank = dist.get_rank()
        world_size = dist.get_world_size()

        # one hot (by rank) tensor of size world_size
        # example:
        # rank 0, world_size 4 => [1, 0, 0, 0]
        # rank 1, world_size 4 => [0, 1, 0, 0]
        # ...
        t = F.one_hot(torch.tensor(rank), num_classes=world_size)

        # after all_reduce t is expected to equal torch.ones(world_size)
        dist.all_reduce(t)

        # adding all elements in t should equal world_size
        derived_world_size = torch.sum(t).item()
        if derived_world_size != world_size:
            raise RuntimeError(
                f"Wrong world size derived. Expected: {world_size}, Got: {derived_world_size}"
            )

        print("Done")
    finally:
        # Release the process group on both the success and the failure path.
        dist.destroy_process_group()


if __name__ == "__main__":
    main()