xref: /aosp_15_r20/external/pytorch/tools/rules_cc/cuda_support.patch (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1diff --git cc/private/toolchain/unix_cc_configure.bzl cc/private/toolchain/unix_cc_configure.bzl
2index ba992fc..e4e8364 100644
3--- cc/private/toolchain/unix_cc_configure.bzl
4+++ cc/private/toolchain/unix_cc_configure.bzl
5@@ -27,6 +27,7 @@ load(
6     "which",
7     "write_builtin_include_directory_paths",
8 )
9+load("@rules_cuda//cuda:toolchain.bzl", "cuda_compiler_deps")
10
11 def _field(name, value):
12     """Returns properly indented top level crosstool field."""
13@@ -397,7 +398,7 @@ def configure_unix_toolchain(repository_ctx, cpu_value, overriden_tools):
14     cxx_opts = split_escaped(get_env_var(
15         repository_ctx,
16         "BAZEL_CXXOPTS",
17-        "-std=c++0x",
18+        "-std=c++11",
19         False,
20     ), ":")
21
22@@ -463,7 +464,7 @@ def configure_unix_toolchain(repository_ctx, cpu_value, overriden_tools):
23             )),
24             "%{cc_compiler_deps}": get_starlark_list([":builtin_include_directory_paths"] + (
25                 [":cc_wrapper"] if darwin else []
26-            )),
27+            ) + cuda_compiler_deps()),
28             "%{cc_toolchain_identifier}": cc_toolchain_identifier,
29             "%{compile_flags}": get_starlark_list(
30                 [
31diff --git cc/private/toolchain/unix_cc_toolchain_config.bzl cc/private/toolchain/unix_cc_toolchain_config.bzl
32index c3cf3ba..1744eb4 100644
33--- cc/private/toolchain/unix_cc_toolchain_config.bzl
34+++ cc/private/toolchain/unix_cc_toolchain_config.bzl
35@@ -25,6 +25,7 @@ load(
36     "variable_with_value",
37     "with_feature_set",
38 )
39+load("@rules_cuda//cuda:toolchain.bzl", "cuda_toolchain_config")
40
41 all_compile_actions = [
42     ACTION_NAMES.c_compile,
43@@ -580,7 +581,8 @@ def _impl(ctx):
44                 ],
45                 flag_groups = [
46                     flag_group(
47-                        flags = ["-iquote", "%{quote_include_paths}"],
48+                        # Use -isystem instead of -iquote: nvcc does not forward -iquote to the host compiler.
49+                        flags = ["-isystem", "%{quote_include_paths}"],
50                         iterate_over = "quote_include_paths",
51                     ),
52                     flag_group(
53@@ -1152,10 +1154,15 @@ def _impl(ctx):
54             unfiltered_compile_flags_feature,
55         ]
56
57+    cuda = cuda_toolchain_config(
58+        cuda_toolchain_info = ctx.attr._cuda_toolchain_info,
59+        compiler_path = ctx.attr.tool_paths["gcc"],
60+    )
61+
62     return cc_common.create_cc_toolchain_config_info(
63         ctx = ctx,
64-        features = features,
65-        action_configs = action_configs,
66+        features = features + cuda.features,
67+        action_configs = action_configs + cuda.action_configs,
68         cxx_builtin_include_directories = ctx.attr.cxx_builtin_include_directories,
69         toolchain_identifier = ctx.attr.toolchain_identifier,
70         host_system_name = ctx.attr.host_system_name,
71@@ -1192,6 +1199,9 @@ cc_toolchain_config = rule(
72         "tool_paths": attr.string_dict(),
73         "toolchain_identifier": attr.string(mandatory = True),
74         "unfiltered_compile_flags": attr.string_list(),
75+        "_cuda_toolchain_info": attr.label(
76+            default = Label("@rules_cuda//cuda:cuda_toolchain_info"),
77+        ),
78     },
79     provides = [CcToolchainConfigInfo],
80 )
81