xref: /aosp_15_r20/external/pytorch/third_party/gloo.BUILD (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1load("@rules_cc//cc:defs.bzl", "cc_library")
2load("@pytorch//tools/rules:cu.bzl", "cu_library")
3load("@pytorch//third_party:substitution.bzl", "template_rule")
4load("@pytorch//tools/config:defs.bzl", "if_cuda")
5
# Placeholder -> replacement table used to turn gloo/config.h.in into
# gloo/config.h. Keys are the CMake-style placeholders (`@VAR@` and
# `cmakedefine01 NAME`); values are the literal text written in their place.
#
# Entry order is significant and must not be re-sorted: the
# `GLOO_HAVE_TRANSPORT_TCP_TLS` key has to precede `GLOO_HAVE_TRANSPORT_TCP`,
# because the shorter key is a prefix of the longer one and would otherwise
# be substituted first, corrupting the TLS line.
#
# NOTE(review): GLOO_USE_CUDA is fixed at 1 here, while the CUDA sources and
# the :gloo_cuda dependency elsewhere in this file are gated by if_cuda() --
# confirm that non-CUDA builds are fine with this value.
_GLOO_CONFIG_SUBSTITUTIONS = {
    # Version triple.
    "@GLOO_VERSION_MAJOR@": "0",
    "@GLOO_VERSION_MINOR@": "5",
    "@GLOO_VERSION_PATCH@": "0",

    # Optional feature switches.
    "cmakedefine01 GLOO_USE_CUDA": "define GLOO_USE_CUDA 1",
    "cmakedefine01 GLOO_USE_NCCL": "define GLOO_USE_NCCL 0",
    "cmakedefine01 GLOO_USE_ROCM": "define GLOO_USE_ROCM 0",
    "cmakedefine01 GLOO_USE_RCCL": "define GLOO_USE_RCCL 0",
    "cmakedefine01 GLOO_USE_REDIS": "define GLOO_USE_REDIS 0",
    "cmakedefine01 GLOO_USE_IBVERBS": "define GLOO_USE_IBVERBS 0",
    "cmakedefine01 GLOO_USE_MPI": "define GLOO_USE_MPI 0",
    "cmakedefine01 GLOO_USE_AVX": "define GLOO_USE_AVX 0",
    "cmakedefine01 GLOO_USE_LIBUV": "define GLOO_USE_LIBUV 0",

    # Transport availability. TCP_TLS intentionally comes before TCP (see the
    # ordering note above).
    "cmakedefine01 GLOO_HAVE_TRANSPORT_TCP_TLS": "define GLOO_HAVE_TRANSPORT_TCP_TLS 1",
    "cmakedefine01 GLOO_HAVE_TRANSPORT_TCP": "define GLOO_HAVE_TRANSPORT_TCP 1",
    "cmakedefine01 GLOO_HAVE_TRANSPORT_IBVERBS": "define GLOO_HAVE_TRANSPORT_IBVERBS 0",
    "cmakedefine01 GLOO_HAVE_TRANSPORT_UV": "define GLOO_HAVE_TRANSPORT_UV 0",
}

# Produces gloo/config.h from its .in template using the table above.
template_rule(
    name = "gloo_config_cmake_macros",
    src = "gloo/config.h.in",
    out = "gloo/config.h",
    substitutions = _GLOO_CONFIG_SUBSTITUTIONS,
)
30
# Header directories exported by :gloo_headers.
_GLOO_HEADER_PATTERNS = [
    "gloo/*.h",
    "gloo/common/*.h",
    "gloo/rendezvous/*.h",
    "gloo/transport/*.h",
    "gloo/transport/tcp/*.h",
    "gloo/transport/tcp/tls/*.h",
]

# Headers matched by the patterns above that are deliberately left out.
_GLOO_HEADER_EXCLUDES = [
    "gloo/rendezvous/redis_store.h",
]

# Header-only target: the checked-in gloo headers (minus the Redis store
# header) plus the generated gloo/config.h. The repository root is added to
# the include path via `includes`.
cc_library(
    name = "gloo_headers",
    hdrs = glob(
        _GLOO_HEADER_PATTERNS,
        exclude = _GLOO_HEADER_EXCLUDES,
    ) + ["gloo/config.h"],
    includes = ["."],
)
50
# CUDA (.cu) translation units of gloo, built through the project's
# cu_library rule. The :gloo target only depends on this when if_cuda()
# selects it.
cu_library(
    name = "gloo_cuda",
    srcs = [
        "gloo/cuda.cu",
        "gloo/cuda_private.cu",
    ],
    deps = [":gloo_headers"],
    # Presumably forwarded to the underlying cc_library so that every object
    # is kept at link time even when nothing references it -- confirm
    # against cu.bzl.
    alwayslink = True,
    visibility = ["//visibility:public"],
)
63
# C++ sources that are compiled in every configuration.
#
# NOTE(review): :gloo_headers exports gloo/transport/tcp/tls/*.h and the
# generated config.h sets GLOO_HAVE_TRANSPORT_TCP_TLS to 1, but no
# gloo/transport/tcp/tls/*.cc pattern is listed here -- confirm that this is
# intentional.
_GLOO_COMMON_SRCS = glob(
    [
        "gloo/*.cc",
        "gloo/common/*.cc",
        "gloo/rendezvous/*.cc",
        "gloo/transport/*.cc",
        "gloo/transport/tcp/*.cc",
    ],
    exclude = [
        # CUDA-specific files are added back below, only under if_cuda().
        "gloo/cuda*.cc",
        "gloo/common/win.cc",
        "gloo/rendezvous/redis_store.cc",
    ],
)

# Host-side C++ sources whose names start with "cuda"; only used under
# if_cuda().
_GLOO_CUDA_HOST_SRCS = glob(["gloo/cuda*.cc"])

# The main gloo library: the common sources, plus the cuda*.cc sources and
# the :gloo_cuda dependency when if_cuda() selects them.
cc_library(
    name = "gloo",
    srcs = _GLOO_COMMON_SRCS + if_cuda(_GLOO_CUDA_HOST_SRCS),
    copts = ["-std=c++17"],
    visibility = ["//visibility:public"],
    deps = [":gloo_headers"] + if_cuda(
        [":gloo_cuda"],
        [],
    ),
)
89