xref: /aosp_15_r20/external/pytorch/test/distributed/rpc/cuda/test_tensorpipe_agent.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# Owner(s): ["oncall: distributed"]
3
4import sys
5
6import torch.distributed as dist
7
8
if not dist.is_available():
    # Every import and test below depends on torch.distributed, so stop here
    # with a success exit status instead of reaching those imports.
    sys.stderr.write("Distributed not available, skipping tests\n")
    raise SystemExit(0)
12
13import torch
14from torch.testing._internal.common_utils import run_tests
15from torch.testing._internal.distributed.rpc.tensorpipe_rpc_agent_test_fixture import (
16    TensorPipeRpcAgentTestFixture,
17)
18from torch.testing._internal.distributed.rpc_utils import (
19    generate_tests,
20    GENERIC_CUDA_TESTS,
21    TENSORPIPE_CUDA_TESTS,
22)
23
24
# Configure the CUDA caching allocator before any test runs, but only when a
# CUDA device is visible.
# NOTE(review): presumably expandable segments are disabled because the
# TensorPipe CUDA transport relies on allocations that expandable segments do
# not support (e.g. CUDA IPC handles). Confirm against the change that added it.
if torch.cuda.is_available():
    torch.cuda.memory._set_allocator_settings("expandable_segments:False")

# Combine the TensorPipe agent fixture with both the generic and the
# TensorPipe-specific CUDA suites. generate_tests presumably returns a mapping
# of class name -> generated test class (TODO confirm in rpc_utils). Each entry
# is installed as a module global so that test discovery can find it.
_tensorpipe_cuda_test_classes = generate_tests(
    "TensorPipe",
    TensorPipeRpcAgentTestFixture,
    GENERIC_CUDA_TESTS + TENSORPIPE_CUDA_TESTS,
    __name__,
)
globals().update(_tensorpipe_cuda_test_classes)
# Remove the temporary so the module namespace holds only the generated classes.
del _tensorpipe_cuda_test_classes
36
37
if __name__ == "__main__":
    # When executed directly as a script, delegate to PyTorch's shared test
    # entry point (torch.testing._internal.common_utils.run_tests).
    run_tests()
40