"""Definitions for using tools like saved_model_cli."""

load("//tensorflow:tensorflow.bzl", "clean_dep", "if_xla_available")
load("//tensorflow:tensorflow.bzl", "tfcompile_target_cpu")
load("//tensorflow/compiler/aot:tfcompile.bzl", "target_llvm_triple")

def _maybe_force_compile(args, force_compile):
    """Returns `args` as-is, or gated behind `if_xla_available`.

    Args:
      args: List of labels/sources to (possibly) gate.
      force_compile: If True, `args` is returned unchanged. Otherwise the
        result of `if_xla_available(args)` is returned.

    Returns:
      `args` when `force_compile` is set, else `if_xla_available(args)`.
    """
    if force_compile:
        return args
    return if_xla_available(args)

def saved_model_compile_aot(
        name,
        directory,
        filegroups,
        cpp_class,
        checkpoint_path = None,
        tag_set = "serve",
        signature_def = "serving_default",
        variables_to_feed = "",
        target_triple = None,
        target_cpu = None,
        multithreading = False,
        force_without_xla_support_flag = True,
        tags = None):
    """Compile a SavedModel directory accessible from a filegroup.

    This target rule takes a path to a filegroup directory containing a
    SavedModel and generates a cc_library with an AOT compiled model.
    For extra details, see the help text for `saved_model_cli aot_compile_cpu`.

    **NOTE** Any variables passed to `variables_to_feed` *must be set by the
    user*. These variables will NOT be frozen and their values will be
    uninitialized in the compiled object (this applies to all input
    arguments from the signature as well).

    Example usage:

    ```
    saved_model_compile_aot(
      name = "aot_compiled_x_plus_y",
      cpp_class = "tensorflow::CompiledModel",
      directory = "//tensorflow/cc/saved_model:testdata/x_plus_y_v2_debuginfo",
      filegroups = [
          "//tensorflow/cc/saved_model:saved_model_half_plus_two",
      ]
    )

    cc_test(
      name = "test",
      srcs = ["test.cc"],
      deps = [
        "//tensorflow/core:test_main",
        ":aot_compiled_x_plus_y",
        "//tensorflow/core:test",
        "//tensorflow/core/platform:logging",
      ],
    )

    In "test.cc":

    #include "third_party/tensorflow/python/tools/aot_compiled_x_plus_y.h"

    TEST(Test, Run) {
      tensorflow::CompiledModel model;
      CHECK(model.Run());
    }
    ```

    Args:
      name: The rule name, and the name prefix of the headers and object file
        emitted by this rule.
      directory: The bazel directory containing saved_model.pb and variables/
        subdirectories.
      filegroups: List of `filegroup` targets; these filegroups contain the
        files pointed to by `directory` and `checkpoint_path`.
      cpp_class: The name of the C++ class that will be generated, including
        namespace; e.g. "my_model::InferenceRunner".
      checkpoint_path: The bazel directory containing `variables.index`. If
        not provided, then `$directory/variables/` is used
        (default for SavedModels).
      tag_set: The tag set to use in the SavedModel.
      signature_def: The name of the signature to use from the SavedModel.
      variables_to_feed: (optional) The names of the variables to feed, a comma
        separated string, or 'all'. If empty, all variables will be frozen and
        none may be fed at runtime.

        **NOTE** Any variables passed to `variables_to_feed` *must be set by
        the user*. These variables will NOT be frozen and their values will be
        uninitialized in the compiled object (this applies to all input
        arguments from the signature as well).
      target_triple: The LLVM target triple to use (defaults to current build's
        target architecture's triple). Similar to clang's -target flag.
      target_cpu: The LLVM cpu name used for compilation. Similar to clang's
        -mcpu flag.
      multithreading: Whether to compile multithreaded AOT code.
        Note, this increases the set of dependencies for binaries using
        the AOT library at both build and runtime. For example,
        the resulting object files may have external dependencies on
        multithreading libraries like nsync.
      force_without_xla_support_flag: Whether to compile even when
        `--define=with_xla_support=true` is not set. If `False`, and the
        define is not passed when building, then the created `cc_library`
        will be empty. In this case, downstream targets should
        conditionally build using macro `tfcompile.bzl:if_xla_available`.
        This flag is used by the TensorFlow build to avoid building on
        architectures that do not support XLA.
      tags: List of target tags.
    """

    # Resolve TensorFlow-internal labels once through clean_dep so that the
    # label used in `tools`, the label used in the `$(location ...)` expansion,
    # and the cc_library dependency all refer to the TensorFlow repository,
    # even when this macro is invoked from a different workspace.
    saved_model_cli = clean_dep("//tensorflow/python/tools:saved_model_cli")
    aot_runtime = clean_dep(
        "//tensorflow/compiler/tf2xla:xla_compiled_cpu_runtime_standalone",
    )

    saved_model = "{}/saved_model.pb".format(directory)
    target_triple = target_triple or target_llvm_triple()
    target_cpu = target_cpu or tfcompile_target_cpu() or ""

    # An empty value is replaced by a quoted empty shell word so that
    # `--variables_to_feed` still receives an argument instead of consuming
    # the next flag on the command line.
    variables_to_feed = variables_to_feed or "''"

    if checkpoint_path:
        checkpoint_index = "{}/variables.index".format(checkpoint_path)
        checkpoint_cmd_args = (
            "--checkpoint_path \"$$(dirname $(location {}))\" ".format(checkpoint_index)
        )
        checkpoint_srcs = [checkpoint_index]
    else:
        checkpoint_cmd_args = ""
        checkpoint_srcs = []

    native.genrule(
        name = "{}_gen".format(name),
        srcs = filegroups + [saved_model] + checkpoint_srcs,
        outs = [
            "{}.h".format(name),
            "{}.o".format(name),
            "{}_metadata.o".format(name),
            "{}_makefile.inc".format(name),
        ],
        cmd = (
            "$(location {}) aot_compile_cpu ".format(saved_model_cli) +
            "--dir \"$$(dirname $(location {}))\" ".format(saved_model) +
            checkpoint_cmd_args +
            "--output_prefix $(@D)/{} ".format(name) +
            "--cpp_class {} ".format(cpp_class) +
            "--variables_to_feed {} ".format(variables_to_feed) +
            "--signature_def_key {} ".format(signature_def) +
            # The bool is rendered as `True`/`False` by str.format.
            "--multithreading {} ".format(multithreading) +
            "--target_triple " + target_triple + " " +
            ("--target_cpu " + target_cpu + " " if target_cpu else "") +
            "--tag_set {} ".format(tag_set)
        ),
        tags = tags,
        # Must be the same label as the one expanded by $(location ...) above.
        tools = [saved_model_cli],
    )

    native.cc_library(
        name = name,
        srcs = _maybe_force_compile(
            [
                ":{}.o".format(name),
                ":{}_metadata.o".format(name),
            ],
            force_compile = force_without_xla_support_flag,
        ),
        hdrs = _maybe_force_compile(
            [
                ":{}.h".format(name),
            ],
            force_compile = force_without_xla_support_flag,
        ),
        tags = tags,
        deps = _maybe_force_compile(
            [aot_runtime],
            force_compile = force_without_xla_support_flag,
        ),
    )