xref: /aosp_15_r20/external/tensorflow/tensorflow/tools/toolchains/remote/configure.bzl (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1"""Repository rule for remote GPU autoconfiguration.
2
This rule creates the Starlark file `remote_execution.bzl` in the generated
repository, expanded from the template
//tensorflow/tools/toolchains/remote:execution.bzl.tpl and
providing the function `gpu_test_tags`.
6
7`gpu_test_tags` will return:
8
  * `local`: unless `REMOTE_GPU_TESTING` is set to `1`, allowing CPU tests to
    run remotely and GPU tests to run locally in the same bazel invocation.
  * `remote-gpu`: if `REMOTE_GPU_TESTING` is set to `1` (surrounding whitespace
    is ignored); this allows rules to set an execution requirement that
    enables a GPU-enabled remote platform.
13"""
14
# Name of the environment variable that opts GPU tests into remote execution.
# It is read via _flag_enabled() and declared in the rule's `environ` list.
_REMOTE_GPU_TESTING = "REMOTE_GPU_TESTING"
16
def _flag_enabled(repository_ctx, flag_name):
    """Reports whether an environment flag is switched on.

    A flag counts as enabled only when the variable is present and its value,
    after stripping surrounding whitespace, is exactly "1". Any other value
    (including "true" or "yes") or an unset variable counts as disabled.

    Args:
      repository_ctx: repository context whose `os.environ` dict is consulted.
      flag_name: name of the environment variable to inspect.

    Returns:
      True if the flag is enabled, False otherwise.
    """
    raw_value = repository_ctx.os.environ.get(flag_name, "")
    return raw_value.strip() == "1"
21
def _remote_execution_configure(repository_ctx):
    """Implementation of the `remote_execution_configure` repository rule.

    Writes two files into the generated repository:
      * `remote_execution.bzl`, expanded from execution.bzl.tpl with the
        `%{gpu_test_tags}` placeholder replaced by `"remote-gpu"` when the
        REMOTE_GPU_TESTING flag is enabled, or `"local"` otherwise.
      * `BUILD`, expanded from BUILD.tpl with no substitutions.

    Args:
      repository_ctx: the repository context for this rule invocation.
    """

    # If we do not support remote GPU test execution, mark GPU tests as local,
    # so remote builds can be combined with locally-run GPU tests.
    if _flag_enabled(repository_ctx, _REMOTE_GPU_TESTING):
        gpu_test_tags = "\"remote-gpu\""
    else:
        gpu_test_tags = "\"local\""

    # The substituted value carries its own double quotes; presumably the
    # template splices it in verbatim as a Starlark string literal (template
    # not visible here -- confirm against execution.bzl.tpl).
    #
    # repository_ctx.template() marks its output executable by default; these
    # are plain .bzl/BUILD files, so request non-executable output.
    repository_ctx.template(
        "remote_execution.bzl",
        Label("//tensorflow/tools/toolchains/remote:execution.bzl.tpl"),
        {
            "%{gpu_test_tags}": gpu_test_tags,
        },
        executable = False,
    )
    repository_ctx.template(
        "BUILD",
        Label("//tensorflow/tools/toolchains/remote:BUILD.tpl"),
        executable = False,
    )
39
# Public repository rule that generates remote_execution.bzl and BUILD.
# REMOTE_GPU_TESTING is declared in `environ` so the rule is tied to that
# variable; per the Bazel repository_rule docs, a change to it triggers a
# refetch of the repository.
remote_execution_configure = repository_rule(
    environ = [_REMOTE_GPU_TESTING],
    implementation = _remote_execution_configure,
)
44