1#!/usr/bin/env python3 2# Owner(s): ["oncall: distributed"] 3 4import contextlib 5import copyreg 6import os 7import sys 8 9import torch 10import torch.distributed as dist 11 12 13if not dist.is_available(): 14 print("Distributed not available, skipping tests", file=sys.stderr) 15 sys.exit(0) 16 17import torch.distributed.rpc as rpc 18import torch.multiprocessing.reductions as TorchMpReductions 19from torch import multiprocessing 20from torch.distributed.rpc.api import _use_rpc_pickler 21from torch.distributed.rpc.internal import _InternalRPCPickler 22from torch.testing._internal.common_utils import run_tests, TestCase 23 24 25@contextlib.contextmanager 26def fs_sharing(): 27 prev_strategy = multiprocessing.get_sharing_strategy() 28 multiprocessing.set_sharing_strategy("file_system") 29 try: 30 yield 31 finally: 32 multiprocessing.set_sharing_strategy(prev_strategy) 33 34 35class ShareMemoryRPCPickler(_InternalRPCPickler): 36 def __init__(self) -> None: 37 super().__init__() 38 self._dispatch_table 39 # pyre-fixme[4]: Attribute must be annotated. 40 self._dispatch_table = copyreg.dispatch_table.copy() 41 42 for t in torch._storage_classes: 43 self._dispatch_table[t] = TorchMpReductions.reduce_storage 44 45 for t in torch._tensor_classes: 46 self._dispatch_table[t] = TorchMpReductions.reduce_tensor 47 self._dispatch_table[torch.Tensor] = TorchMpReductions.reduce_tensor 48 self._dispatch_table[ 49 torch.nn.parameter.Parameter 50 ] = TorchMpReductions.reduce_tensor 51 52 53def worker_loop(a): 54 rpc.init_rpc("worker1", rank=1, world_size=2) 55 rpc.shutdown() 56 57 58def worker_fn(m): 59 pass 60 61 62class TestRPCPickler(TestCase): 63 def test_case(self): 64 os.environ["MASTER_ADDR"] = "localhost" 65 os.environ["MASTER_PORT"] = "29500" 66 67 with fs_sharing(): 68 r = multiprocessing.spawn(worker_loop, join=False) 69 70 try: 71 with _use_rpc_pickler(ShareMemoryRPCPickler()): 72 rpc.init_rpc("worker0", rank=0, world_size=2) 73 m = torch.nn.Linear(1, 2) 74 m.share_memory() 75 rref = rpc.remote("worker1", worker_fn, args=(m,)) 76 77 rref.to_here() 78 finally: 79 rpc.shutdown() 80 r.join() 81 82 83if __name__ == "__main__": 84 run_tests() 85