xref: /aosp_15_r20/external/tensorflow/tensorflow/python/tools/tools.bzl (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1"""Definitions for using tools like saved_model_cli."""
2
3load("//tensorflow:tensorflow.bzl", "clean_dep", "if_xla_available")
4load("//tensorflow:tensorflow.bzl", "tfcompile_target_cpu")
5load("//tensorflow/compiler/aot:tfcompile.bzl", "target_llvm_triple")
6
def _maybe_force_compile(args, force_compile):
    """Returns `args` as-is when forced, otherwise gated on XLA availability.

    Args:
      args: The attribute value (e.g. a list of labels) to pass through.
      force_compile: If truthy, `args` is returned unchanged; otherwise the
        value is wrapped with `if_xla_available`.

    Returns:
      `args`, or the result of `if_xla_available(args)`.
    """
    if not force_compile:
        return if_xla_available(args)
    return args
12
def saved_model_compile_aot(
        name,
        directory,
        filegroups,
        cpp_class,
        checkpoint_path = None,
        tag_set = "serve",
        signature_def = "serving_default",
        variables_to_feed = "",
        target_triple = None,
        target_cpu = None,
        multithreading = False,
        force_without_xla_support_flag = True,
        tags = None):
    """Compile a SavedModel directory accessible from a filegroup.

    This target rule takes a path to a filegroup directory containing a
    SavedModel and generates a cc_library with an AOT compiled model.
    For extra details, see the help for saved_model_cli's aot_compile_cpu help.

    Two targets are created: a genrule named `<name>_gen` that runs
    `saved_model_cli aot_compile_cpu` and emits `<name>.h`, `<name>.o`,
    `<name>_metadata.o` and `<name>_makefile.inc`, and a cc_library named
    `<name>` that wraps the generated header and object files.

    **NOTE** Any variables passed to `variables_to_feed` *must be set by the
    user*.  These variables will NOT be frozen and their values will be
    uninitialized in the compiled object (this applies to all input
    arguments from the signature as well).

    Example usage:

    ```
    saved_model_compile_aot(
      name = "aot_compiled_x_plus_y",
      cpp_class = "tensorflow::CompiledModel",
      directory = "//tensorflow/cc/saved_model:testdata/x_plus_y_v2_debuginfo",
      filegroups = [
          "//tensorflow/cc/saved_model:saved_model_half_plus_two",
      ]
    )

    cc_test(
      name = "test",
      srcs = ["test.cc"],
      deps = [
        "//tensorflow/core:test_main",
        ":aot_compiled_x_plus_y",
        "//tensorflow/core:test",
        "//tensorflow/core/platform:logging",
      ],
    )

    In "test.cc":

    #include "third_party/tensorflow/python/tools/aot_compiled_x_plus_y.h"

    TEST(Test, Run) {
      tensorflow::CompiledModel model;
      CHECK(model.Run());
    }
    ```

    Args:
      name: The rule name, and the name prefix of the headers and object file
        emitted by this rule.
      directory: The bazel directory containing saved_model.pb and variables/
        subdirectories.
      filegroups: List of `filegroup` targets; these filegroups contain the
        files pointed to by `directory` and `checkpoint_path`.
      cpp_class: The name of the C++ class that will be generated, including
        namespace; e.g. "my_model::InferenceRunner".
      checkpoint_path: The bazel directory containing `variables.index`.  If
        not provided, then `$directory/variables/` is used
        (default for SavedModels).
      tag_set: The tag set to use in the SavedModel.
      signature_def: The name of the signature to use from the SavedModel.
      variables_to_feed: (optional) The names of the variables to feed, a comma
        separated string, or 'all'.  If empty, all variables will be frozen and none
        may be fed at runtime.

        **NOTE** Any variables passed to `variables_to_feed` *must be set by
        the user*.  These variables will NOT be frozen and their values will be
        uninitialized in the compiled object (this applies to all input
        arguments from the signature as well).
      target_triple: The LLVM target triple to use (defaults to current build's
        target architecture's triple).  Similar to clang's -target flag.
      target_cpu: The LLVM cpu name used for compilation.  Similar to clang's
        -mcpu flag.
      multithreading: Whether to compile multithreaded AOT code.
        Note, this increases the set of dependencies for binaries using
        the AOT library at both build and runtime.  For example,
        the resulting object files may have external dependencies on
        multithreading libraries like nsync.
      force_without_xla_support_flag: Whether to compile even when
        `--define=with_xla_support=true` is not set.  If `False`, and the
        define is not passed when building, then the created `cc_library`
        will be empty.  In this case, downstream targets should
        conditionally build using macro `tensorflow.bzl:if_xla_available`.
        This flag is used by the TensorFlow build to avoid building on
        architectures that do not support XLA.
      tags: List of target tags.
    """
    saved_model = "{}/saved_model.pb".format(directory)
    target_triple = target_triple or target_llvm_triple()
    target_cpu = target_cpu or tfcompile_target_cpu() or ""

    # An empty value must still reach the shell as an explicit (empty)
    # argument to --variables_to_feed, hence the quoted empty string.
    variables_to_feed = variables_to_feed or "''"

    # Labels owned by TensorFlow go through clean_dep so that they keep
    # pointing at the TensorFlow repository when this macro is instantiated
    # from a BUILD file in another repository.  The same CLI label is used for
    # both `tools` and the `$(location ...)` reference in `cmd`.
    saved_model_cli = clean_dep("//tensorflow/python/tools:saved_model_cli")
    xla_runtime = clean_dep(
        "//tensorflow/compiler/tf2xla:xla_compiled_cpu_runtime_standalone",
    )

    if checkpoint_path:
        checkpoint_index = "{}/variables.index".format(checkpoint_path)

        # `$$` escapes `$` for Bazel so the shell sees `$(dirname ...)`.
        checkpoint_cmd_args = (
            "--checkpoint_path \"$$(dirname $(location {}))\" ".format(
                checkpoint_index,
            )
        )
        checkpoint_srcs = [checkpoint_index]
    else:
        checkpoint_cmd_args = ""
        checkpoint_srcs = []

    native.genrule(
        name = "{}_gen".format(name),
        srcs = filegroups + [saved_model] + checkpoint_srcs,
        outs = [
            "{}.h".format(name),
            "{}.o".format(name),
            "{}_metadata.o".format(name),
            "{}_makefile.inc".format(name),
        ],
        cmd = (
            "$(location {}) aot_compile_cpu ".format(saved_model_cli) +
            "--dir \"$$(dirname $(location {}))\" ".format(saved_model) +
            checkpoint_cmd_args +
            "--output_prefix $(@D)/{} ".format(name) +
            "--cpp_class {} ".format(cpp_class) +
            "--variables_to_feed {} ".format(variables_to_feed) +
            "--signature_def_key {} ".format(signature_def) +
            "--multithreading {} ".format(multithreading) +
            "--target_triple " + target_triple + " " +
            # Only pass --target_cpu when a cpu name is actually known.
            ("--target_cpu " + target_cpu + " " if target_cpu else "") +
            "--tag_set {} ".format(tag_set)
        ),
        tags = tags,
        tools = [saved_model_cli],
    )

    # Unless forced, srcs/hdrs/deps are all gated on XLA availability, which
    # leaves the library empty when XLA support is not enabled.
    native.cc_library(
        name = name,
        srcs = _maybe_force_compile(
            [
                ":{}.o".format(name),
                ":{}_metadata.o".format(name),
            ],
            force_compile = force_without_xla_support_flag,
        ),
        hdrs = _maybe_force_compile(
            [
                ":{}.h".format(name),
            ],
            force_compile = force_without_xla_support_flag,
        ),
        tags = tags,
        deps = _maybe_force_compile(
            [xla_runtime],
            force_compile = force_without_xla_support_flag,
        ),
    )
178