/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_compilation_cache.h"

#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <variant>

#include "tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h"
#include "absl/base/call_once.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/variant.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/jit/xla_compilation_cache.pb.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/protobuf/debug_event.pb.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {
namespace {

using TensorTypeAndShape = XlaCompilationCache::Signature::TensorTypeAndShape;

constexpr char kXlaSerializedCacheKeySeparator[] = "__";

// Functor that converts a Signature's arg to a human readable string.
struct SignatureHumanStringAppender {
  explicit SignatureHumanStringAppender(string* dest) : dest(dest) {}
  string* dest;
  void operator()(const Tensor& arg) {
    absl::StrAppend(dest, "; ", arg.DebugString());
  }
  void operator()(const TensorTypeAndShape& arg) {
    absl::StrAppend(dest, ",", DataTypeString(arg.first));
    absl::StrAppend(dest, " [", absl::StrJoin(arg.second, ","), "]");
  }
};

// Functor that compares the arg values of two different signatures. Returns
// true when the args are not equal.
struct SignatureNotEqual {
  bool operator()(const Tensor& arg, const Tensor& other) {
    return arg.dtype() != other.dtype() || arg.shape() != other.shape() ||
           arg.tensor_data() != other.tensor_data();
  }
  bool operator()(const TensorTypeAndShape& arg,
                  const TensorTypeAndShape& other) {
    return arg.first != other.first || arg.second != other.second;
  }
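  // Mixed-kind comparisons: a constant-valued arg (Tensor) and a
  // type/shape-only arg (TensorTypeAndShape) are never equal, so the two
  // overloads below unconditionally report inequality.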
  bool operator()(const Tensor& arg, const TensorTypeAndShape& other) {
    return true;
  }
  bool operator()(const TensorTypeAndShape& arg, const Tensor& other) {
    return true;
  }
};

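// Note: for Tensor args the hash below folds in the dtype, the raw tensor
// bytes, and every dimension size (the same fields SignatureNotEqual
// compares), so signatures that compare equal also hash equally.
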
// Functor that incrementally computes a Signature's hash given its current
// hash and one of its args.
struct SignatureHashCombiner {
  explicit SignatureHashCombiner(const uint64 h) : h(h) {}
  uint64 h;
  uint64 operator()(const Tensor& arg) {
    h = Hash64Combine(h, std::hash<int>()(static_cast<int>(arg.dtype())));
    h = Hash64Combine(
        h, Hash64(arg.tensor_data().data(), arg.tensor_data().size()));
    for (int dim = 0; dim < arg.dims(); ++dim) {
      h = Hash64Combine(h, std::hash<int>()(arg.dim_size(dim)));
    }
    return h;
  }
  uint64 operator()(const TensorTypeAndShape& arg) {
    h = Hash64Combine(h, std::hash<int>()(static_cast<int>(arg.first)));
    h = Hash64Combine(h, std::hash<int>()(arg.second.size()));
    for (int dim : arg.second) {
      h = Hash64Combine(h, std::hash<int>()(dim));
    }
    return h;
  }
};

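// Illustrative rendering (hypothetical values): a key with prefix "foo",
// signature fingerprint 123, cluster fingerprint 456, and device type
// "XLA_GPU_JIT" becomes "foo__123__456__XLA_GPU_JIT"; an empty prefix drops
// the leading separator.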
std::string XlaSerializedCacheKeyToString(const XlaSerializedCacheKey& key) {
  return absl::StrCat(
      key.prefix(), key.prefix().empty() ? "" : kXlaSerializedCacheKeySeparator,
      key.signature_fingerprint(), kXlaSerializedCacheKeySeparator,
      key.cluster_fingerprint(), kXlaSerializedCacheKeySeparator,
      key.device_type());
}

}  // namespace

constexpr int64_t XlaCompilationCache::kDefaultCompilationThreshold;
constexpr int64_t
    XlaCompilationCache::AsyncCompilationState::kNumCompilerThreads;
constexpr int64_t
    XlaCompilationCache::AsyncCompilationState::kMaxNumOngoingCompilations;

XlaCompilationCache::XlaCompilationCache(Config config,
                                         xla::LocalClient* client,
                                         DeviceType device_type)
    : client_(client),
      device_type_(std::move(device_type)),
      disable_strict_signature_checks_(config.disable_strict_signature_checks),
      persistance_prefix_(config.persistance_prefix),
      persistent_cache_directory_(config.persistent_cache_directory) {}

XlaCompilationCache::~XlaCompilationCache() {
  // Ensure all uses of our programs have completed by waiting for all stream
  // executors to finish their outstanding activity.
  for (auto* executor : client_->backend().stream_executors()) {
    bool ok = executor->SynchronizeAllActivity();
    if (!ok) {
      LOG(ERROR) << "Error synchronizing activity while waiting for all "
                    "programs to complete";
    }
  }
  // Wait for all outstanding compilations to finish by resetting the thread
  // pool explicitly in this top-level destructor. Without this, the pointer
  // would only be reset when AsyncCompilationState is destroyed, which
  // depends on the declaration order of XlaCompilationCache's members and is
  // error-prone if that order changes.
  async_compilation_state_.compiler_threads.reset();
  // TODO(b/110813685): Think about the program ownership model. Programs are
  // currently owned by the compilation cache, which means we must wait for
  // program completion in the destructor. There are multiple compilation
  // caches around, which complicates things a little. Perhaps having programs
  // be shared_ptrs (an invasive change) would make the model easier to reason
  // about?
}

string XlaCompilationCache::DebugString() const {
  return "XLA JIT compilation cache";
}

// Computes a human-readable string that encodes the signature's name and the
// types/shapes (or constant values) of its arguments.
string XlaCompilationCache::Signature::HumanString() const {
  string result = name;
  for (const auto& a : args) {
    absl::visit(SignatureHumanStringAppender(&result), a);
  }
  return result;
}

bool XlaCompilationCache::Signature::operator==(const Signature& other) const {
  if (name != other.name) return false;
  if (args.size() != other.args.size()) return false;
  for (int i = 0, end = args.size(); i < end; ++i) {
    if (absl::visit(SignatureNotEqual(), args[i], other.args[i])) {
      return false;
    }
  }
  return true;
}

uint64 XlaCompilationCache::Signature::Hash::operator()(
    const XlaCompilationCache::Signature& signature) const {
  uint64 h = std::hash<string>()(signature.name);
  for (const auto& arg : signature.args) {
    h = absl::visit(SignatureHashCombiner(h), arg);
  }
  return h;
}

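// Illustrative example: for a call whose first argument is the compile-time
// constant scalar 3 and whose second is a float [2,3] parameter, the
// signature records the constant's full value for the first arg and the
// (DT_FLOAT, {2,3}) pair for the second, so changing either the constant or
// the parameter's shape produces a different cache entry.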
StatusOr<XlaCompilationCache::Signature> XlaCompilationCache::BuildSignature(
    const NameAttrList& function,
    absl::Span<const XlaCompiler::Argument> args) {
  Signature signature;
  signature.name = Canonicalize(function.name(), AttrSlice(&function.attr()));

  for (const XlaCompiler::Argument& arg : args) {
    switch (arg.kind) {
      case XlaCompiler::Argument::kConstant:
      case XlaCompiler::Argument::kConstantResource:
        signature.args.push_back(arg.constant_value);
        break;
      case XlaCompiler::Argument::kParameter:
      case XlaCompiler::Argument::kResource:
        signature.args.push_back(
            TensorTypeAndShape(arg.type, arg.DimensionSizesAsInlinedVector()));
        break;
      default:
        return errors::InvalidArgument(
            "Unhandled argument kind in XlaCompilationCache: ",
            arg.HumanString());
    }
  }
  return std::move(signature);
}

static std::vector<const xla::Shape*> GetShapePointers(
    absl::Span<const xla::Shape> shapes) {
  std::vector<const xla::Shape*> shape_ptrs;
  shape_ptrs.reserve(shapes.size());
  for (const auto& shape : shapes) {
    shape_ptrs.push_back(&shape);
  }
  return shape_ptrs;
}

static xla::ExecutableBuildOptions GetBuildOptions(
    const XlaCompiler::Options& options,
    const XlaCompiler::CompilationResult& result, int default_device_ordinal) {
  xla::ExecutableBuildOptions build_options;
  if (result.collective_info) {
    build_options.set_num_replicas(result.collective_info->group_size);
  }
  build_options.set_device_ordinal(options.device_ordinal != -1
                                       ? options.device_ordinal
                                       : default_device_ordinal);
  build_options.set_result_layout(result.xla_output_shape);
  build_options.set_device_allocator(options.device_allocator.get());
  build_options.set_alias_passthrough_params(options.alias_passthrough_params);
  build_options.mutable_debug_options()->set_xla_detailed_logging_and_dumping(
      options.detailed_logging);
  if (tensorflow::OpDeterminismRequired()) {
    build_options.mutable_debug_options()->set_xla_gpu_deterministic_ops(true);
  }
  return build_options;
}

Status XlaCompilationCache::BuildExecutable(
    const XlaCompiler::Options& options,
    const XlaCompiler::CompilationResult& result,
    std::unique_ptr<xla::LocalExecutable>* executable) {
  VLOG(2) << "Compiling to local executable";

  std::vector<const xla::Shape*> argument_layouts =
      GetShapePointers(result.xla_input_shapes);
  xla::ExecutableBuildOptions build_options =
      GetBuildOptions(options, result, client_->default_device_ordinal());
  TF_ASSIGN_OR_RETURN(
      auto executables,
      client_->Compile(*result.computation, argument_layouts, build_options));
  TF_RET_CHECK(executables.size() == 1);
  *executable = std::move(executables[0]);
  return OkStatus();
}

StatusOr<std::unique_ptr<xla::AotCompilationResult>>
XlaCompilationCache::BuildSerializedExecutable(
    const XlaCompiler::Options& options,
    const XlaCompiler::CompilationResult& result) {
  VLOG(2) << "Compiling to serialized (ahead-of-time) executable";

  std::vector<const xla::Shape*> argument_layouts =
      GetShapePointers(result.xla_input_shapes);
  xla::ExecutableBuildOptions build_options =
      GetBuildOptions(options, result, client_->default_device_ordinal());
  TF_ASSIGN_OR_RETURN(
      std::vector<std::unique_ptr<xla::AotCompilationResult>> aot_results,
      client_->CompileAheadOfTime(*result.computation, argument_layouts,
                                  build_options));
  TF_RET_CHECK(aot_results.size() == 1);
  return std::move(aot_results[0]);
}

StatusOr<std::unique_ptr<xla::LocalExecutable>>
XlaCompilationCache::LoadExecutable(
    const XlaCompiler::Options& options,
    const XlaCompiler::CompilationResult& result,
    const std::string& serialized_aot_result) {
  VLOG(2) << "Loading local executable from serialized AOT result.";

  xla::ExecutableBuildOptions build_options =
      GetBuildOptions(options, result, client_->default_device_ordinal());
  return client_->Load(serialized_aot_result, build_options);
}

Status XlaCompilationCache::Compile(
    const XlaCompiler::Options& options, const NameAttrList& function,
    const std::vector<XlaCompiler::Argument>& args,
    const XlaCompiler::CompileOptions& compile_options,
    CompileMode compile_mode,
    const XlaCompiler::CompilationResult** out_compilation_result,
    xla::LocalExecutable** out_executable) {
  return CompileImpl(compile_options, options, function, args,
                     /*ctx=*/nullptr, CompileScope::kFunction, compile_mode,
                     out_compilation_result, out_executable);
}

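// Worked example for the heuristic below (illustrative numbers): a cluster
// compiled 20 times but executed only 300 times has seen 15 executions per
// compile, below the required 50, so it is considered megamorphic.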
static bool ShouldBeMegamorphic(int64_t compile_count,
                                int64_t execution_count) {
  const int64_t kCompileThreshold = 10;
  const int64_t kMinExecutionsPerCompile = 50;

  // This heuristic is trying to capture the following property: have we sunk
  // a certain minimum amount of compile time into the cluster that didn't
  // quite "pay off"?
  return compile_count > kCompileThreshold &&
         execution_count < kMinExecutionsPerCompile * compile_count;
}

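// Illustrative shape of the graph built below for a two-input, one-output op
// (each _arg node also receives a control edge from _SOURCE, omitted here):
//
//   _arg0 --> op --> _retval0
//   _arg1 ----^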
StatusOr<std::unique_ptr<Graph>> CreateGraph(
    const NodeDef& node_def, absl::Span<const XlaCompiler::Argument> args,
    absl::Span<const DataType> result_types) {
  // TODO(b/74182462): We implement this by creating a new dummy Graph
  // including _Arg nodes, and let CompileGraph walk it. This could be
  // optimized.
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  // First create the actual node we care about computing.
  TF_ASSIGN_OR_RETURN(Node * main_node, graph->AddNode(node_def));

  // Create dummy _Arg nodes. Link these to `main_node` and also via a control
  // dependency edge to the _SOURCE node.
  for (int64_t i = 0, end = args.size(); i < end; ++i) {
    Node* node;
    string arg_name = absl::StrCat("_arg", i);
    Status status =
        NodeBuilder(arg_name, FunctionLibraryDefinition::kArgOp)
            .ControlInput(graph->source_node())
            .Attr("T", args[i].kind == XlaCompiler::Argument::kResource
                           ? DT_RESOURCE
                           : args[i].type)
            .Attr("index", i)
            .Finalize(graph.get(), &node);
    TF_RETURN_IF_ERROR(status);
    graph->AddEdge(node, 0, main_node, i);
  }

  // Similarly with return values: create dummy _Retval nodes fed by
  // `main_node`.
  for (int64_t i = 0, end = result_types.size(); i < end; ++i) {
    Node* node;
    string retval_name = absl::StrCat("_retval", i);
    Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp)
                        .Input(main_node, i)
                        .Attr("T", result_types[i])
                        .Attr("index", i)
                        .Finalize(graph.get(), &node);
    TF_RETURN_IF_ERROR(status);
  }
  FixupSourceAndSinkEdges(graph.get());
  return graph;
}

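// Bridge selection below, in brief: the old (non-MLIR) bridge is used when
// the MLIR bridge is explicitly disabled, when the op is VarIsInitializedOp,
// or when the bridge is not explicitly enabled and the target is not the TPU
// XLA JIT device. Otherwise the MLIR bridge is attempted first, falling back
// to the old bridge on failure unless MLIR was explicitly enabled.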
Status XlaSingleOpToHlo(
    XlaCompiler* compiler, const XlaCompiler::Options& options,
    const std::vector<XlaCompiler::Argument>& args,
    const XlaCompiler::SingleOpCompileArgument& single_op_compile_argument,
    const XlaCompiler::CompileOptions& compile_options,
    XlaCompiler::CompilationResult* compilation_result) {
  const std::vector<DataType>& result_dtypes =
      single_op_compile_argument.output_dtypes;
  const NodeDef& node_def = single_op_compile_argument.node_def;
  TF_ASSIGN_OR_RETURN(
      auto graph,
      CreateGraph(node_def, args, single_op_compile_argument.output_dtypes));

  auto compile_with_old_bridge = [&]() {
    *compilation_result = {};
    return compiler->CompileGraph(compile_options, node_def.name(),
                                  std::move(graph), args, compilation_result);
  };

  const ConfigProto* config = &(single_op_compile_argument.config_proto);
  auto bridge_rollout = GetMlirBridgeRolloutState(
      config ? std::optional<ConfigProto>(*config) : std::nullopt);
  if (bridge_rollout ==
          ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED ||
      node_def.op() == "VarIsInitializedOp" ||
      (bridge_rollout !=
           ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED &&
       options.device_type.type_string() != DEVICE_TPU_XLA_JIT)) {
    return compile_with_old_bridge();
  }

  GraphDebugInfo debug_info;
  std::vector<std::string> control_rets;
  if (result_dtypes.empty()) {
    control_rets.push_back(node_def.name());
  }

  bool mlir_enabled = (bridge_rollout ==
                       ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED);
  VLOG(1) << "Attempting MLIR bridge."
          << (mlir_enabled ? " MLIR is explicitly enabled." : "");
  auto mlir_result = CompileGraphToXlaHlo(
      *graph, mlir::SpanToArrayRef<XlaCompiler::Argument>(args), control_rets,
      options.device_type.type_string(), compile_options.use_tuple_arg,
      /*analyse_graph=*/!mlir_enabled, *options.flib_def, debug_info,
      options.shape_determination_fns, compilation_result);

  if (mlir_result.ok() || mlir_enabled) {
    return mlir_result;
  }

  VLOG(2) << "Failed second phase of the MLIR bridge. Will "
             "retry with the old bridge. MLIR bridge compilation status: "
          << mlir_result;
  return compile_with_old_bridge();
}

Status XlaCompilationCache::CompileSingleOp(
    const XlaCompiler::Options& options,
    const std::vector<XlaCompiler::Argument>& args, OpKernelContext* ctx,
    const XlaCompiler::CompileOptions& compile_options,
    const XlaCompiler::CompilationResult** out_compilation_result,
    xla::LocalExecutable** out_executable) {
  const NodeDef& def = ctx->op_kernel().def();
  NameAttrList name;
  name.set_name(def.op());
  *name.mutable_attr() = def.attr();
  // Remove the "_class" attribute from the attribute set used to create the
  // compilation cache key. This attribute is information for the colocator
  // and causes false uniqueness between nodes.
  name.mutable_attr()->erase("_class");
  return CompileImpl(compile_options, options, name, args, ctx,
                     CompileScope::kOp, CompileMode::kStrict,
                     out_compilation_result, out_executable);
}

namespace {
// Print something that users can search for to definitively ascertain that
// XLA was used for their TF model.
//
// Prints only once to avoid spamming LOG(INFO).
void LogOnceXlaCompiledFirstCluster() {
  static absl::once_flag log_once;
  absl::call_once(log_once, [] {
    LOG(INFO) << "Compiled cluster using XLA!  This line is logged at most "
                 "once for the lifetime of the process.";
  });
}
}  // namespace

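// Flow of CompileStrict, in brief: lower the cluster (or single op) to HLO;
// then either load a previously serialized executable from the persistent
// cache (when persistent_cache_directory_ is set and the entry verifies), or
// build the executable with the local client and, when persisting, serialize
// it back out. Compile statistics and an XlaJitCompilationActivity are
// recorded in both cases.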
Status XlaCompilationCache::CompileStrict(
    const Signature& sig, Entry* entry,
    const XlaCompiler::CompileOptions& compile_options,
    const XlaCompiler::Options& options,
    const std::vector<XlaCompiler::Argument>& args,
    const NameAttrList& function, OpKernelContext* ctx, CompileScope scope) {
  tensorflow::Env* env = tensorflow::Env::Default();
  const uint64 compile_start_us = env->NowMicros();

  XlaCompiler compiler(options);
  entry->compile_state = CompileState::kCompiled;
  entry->compilation_status = [&] {
    if (scope == CompileScope::kOp) {
      XlaCompiler::SingleOpCompileArgument single_op_arg;
      std::vector<DataType> output_dtypes(ctx->num_outputs());
      for (int i = 0; i < output_dtypes.size(); ++i) {
        output_dtypes[i] = ctx->expected_output_dtype(i);
      }
      single_op_arg.output_dtypes = std::move(output_dtypes);
      single_op_arg.node_def = ctx->op_kernel().def();
      auto* config_proto = ctx->function_library()->config_proto();
      if (config_proto != nullptr) {
        single_op_arg.config_proto = *config_proto;
      }
      return XlaSingleOpToHlo(&compiler, options, args, single_op_arg,
                              compile_options, &entry->compilation_result);
    } else {
      CHECK(scope == CompileScope::kFunction);  // Crash OK
      return compiler.CompileFunction(compile_options, function, args,
                                      &entry->compilation_result);
    }
  }();
  TF_RETURN_IF_ERROR(entry->compilation_status);
  TF_RET_CHECK(entry->executable.get() == nullptr);
  TF_RET_CHECK(entry->compilation_result.computation != nullptr);

  std::optional<XlaSerializedCacheEntry> serialized_entry;
  if (!persistent_cache_directory_.empty()) {
    const xla::HloModuleProto& hlo_module =
        entry->compilation_result.computation->proto();

    XlaSerializedCacheKey cache_key = BuildSerializedCacheKey(sig, hlo_module);

    {
      XLA_SCOPED_LOGGING_TIMER(absl::StrCat(
          "Try loading serialized cache entry:", sig.HumanString()));
      TF_ASSIGN_OR_RETURN(serialized_entry, TryLoadSerializedEntry(cache_key));
    }

    if (serialized_entry.has_value()) {
      TF_RETURN_IF_ERROR(
          VerifyLoadedCacheEntry(cache_key, hlo_module, *serialized_entry));
    }
  }

  if (serialized_entry.has_value()) {
    VLOG(1) << "Loading cached entry for: " << sig.HumanString();
    StatusOr<std::unique_ptr<xla::LocalExecutable>> executable = LoadExecutable(
        options, entry->compilation_result, serialized_entry->executable());
    entry->compilation_status = executable.status();
    if (executable.ok()) {
      entry->executable = *std::move(executable);
    }
  } else {
    entry->compilation_status =
        BuildExecutable(options, entry->compilation_result, &entry->executable);

    // Caching is attempted regardless of entry->compilation_status. To take
    // advantage of newer compilation code, a cache flush is required.
    if (!persistent_cache_directory_.empty()) {
      XLA_SCOPED_LOGGING_TIMER(absl::StrCat(
          "Serializing and saving cache entry: ", sig.HumanString()));
      TF_ASSIGN_OR_RETURN(XlaSerializedCacheEntry serialized_entry,
                          SerializeEntry(options, sig, *entry));
      TF_RETURN_IF_ERROR(SaveSerializedEntry(std::move(serialized_entry)));
    }
  }

  const uint64 compile_end_us = env->NowMicros();
  const uint64 compile_time_us = compile_end_us - compile_start_us;
  metrics::UpdateXlaCompilationTime(compile_time_us);

  mutex_lock lock(cluster_compile_stats_mu_);
  const std::string& function_name = function.name();
  auto it = cluster_compile_stats_.find(function_name);
  const double compile_time_s = compile_time_us / 1.0e6;
  it->second.compile_count++;
  it->second.cumulative_compile_time_us += compile_time_us;
  LogOnceXlaCompiledFirstCluster();
  VLOG(1) << "compiled " << function_name << " " << it->second.compile_count
          << " times, compile time: " << compile_time_us
          << " us, cumulative: " << it->second.cumulative_compile_time_us
          << " us ("
          << tensorflow::strings::HumanReadableElapsedTime(compile_time_s)
          << " / "
          << tensorflow::strings::HumanReadableElapsedTime(
                 it->second.cumulative_compile_time_us / 1.0e6)
          << ")";

  XlaJitCompilationActivity jit_compilation_activity;
  jit_compilation_activity.set_cluster_name(function_name);
  jit_compilation_activity.set_compile_count(it->second.compile_count);
  jit_compilation_activity.set_compile_time_us(compile_time_us);
  jit_compilation_activity.set_cumulative_compile_time_us(
      it->second.cumulative_compile_time_us);
  jit_compilation_activity.set_used_persistent_cache(
      serialized_entry.has_value());
  TF_RETURN_IF_ERROR(BroadcastXlaActivity(std::move(jit_compilation_activity)));

  return OkStatus();
}

Status XlaCompilationCache::CompileAsynchronous(
    const Signature& signature, Entry* entry,
    const XlaCompiler::CompileOptions& compile_options,
    const XlaCompiler::Options& options,
    const std::vector<XlaCompiler::Argument>& args,
    const NameAttrList& function, OpKernelContext* ctx, CompileScope scope) {
  // Explicitly capture all required data by value for async compilation.
  entry->compile_state = CompileState::kCompiling;
  {
    mutex_lock lock(async_compilation_state_.async_compilation_state_mu);
    async_compilation_state_.num_ongoing_compilations++;
  }
  // Don't move the above code into the thread function, as it synchronously
  // updates the async compilation state!

  // When the ThreadPool for the compilation cache is destroyed, it waits for
  // compilations to have finished. This means that both 'entry' and 'this'
  // will be alive for the duration of the compilation.
  // !!Pay attention when additional variables must be captured by this
  // lambda!! All values are captured by value. Make sure that all pointer
  // values (like entry) do not get freed until the lambda has finished.
  const std::string& function_name = function.name();
  async_compilation_state_.compiler_threads->Schedule([=] {
    Entry local_entry;
    VLOG(2) << "Starting asynchronous compilation of cluster " << function_name
            << '.';
    // We don't need to lock local_entry.mu, but do it anyway to satisfy
    // thread safety analysis.
    mutex_lock entry_lock(local_entry.mu);
    Status s = CompileStrict(signature, &local_entry, compile_options, options,
                             args, function, ctx, scope);
    VLOG(2) << "Finished asynchronous compilation of cluster "
            << function_name << '.';
    {
      mutex_lock lock(async_compilation_state_.async_compilation_state_mu);
      async_compilation_state_.num_ongoing_compilations--;
    }
    {  // Populate the original entry with the compilation result.
      mutex_lock entry_lock(entry->mu);
      if (!s.ok()) {
        entry->compilation_status = s;
      } else {
        entry->compilation_status = local_entry.compilation_status;
      }
      entry->compilation_result = local_entry.compilation_result;
      entry->compile_state = local_entry.compile_state;
      entry->executable = std::move(local_entry.executable);
    }
  });
  return OkStatus();
}

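// Compile-mode policy, in brief: kStrict always compiles; kAsync compiles
// right away unless kMaxNumOngoingCompilations compilations are already in
// flight; kLazy compiles on a cluster's first execution or once its request
// count reaches kDefaultCompilationThreshold.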
bool XlaCompilationCache::ShouldCompileCluster(CompileMode compile_mode,
                                               bool is_first_execution,
                                               int64_t current_request_count,
                                               const NameAttrList& function) {
  std::optional<int64_t> compile_threshold;
  if (compile_mode == CompileMode::kLazy) {
    compile_threshold = kDefaultCompilationThreshold;
  } else if (compile_mode == CompileMode::kAsync) {
    compile_threshold = 0;  // for now, always compile right away.
  }

  if (compile_mode == CompileMode::kStrict) {
    // Lazy compilation is disabled.
    return true;
  }

  if (is_first_execution) {
    return true;
  }

  if (compile_mode == CompileMode::kAsync) {
    // Asynchronous compilation is enabled.
    mutex_lock lock(async_compilation_state_.async_compilation_state_mu);
    if (async_compilation_state_.num_ongoing_compilations >=
        async_compilation_state_.kMaxNumOngoingCompilations) {
      VLOG(2) << "Not asynchronously compiling cluster " << function.name()
              << " because of too many ongoing compilations.";
      return false;
    }
  }

  bool reached_compile_threshold = current_request_count >= *compile_threshold;
  if (!reached_compile_threshold) {
    VLOG(2) << "Not compiling cluster " << function.name()
            << " because it has not reached the compile threshold; threshold "
            << "is " << *compile_threshold << ", execution count "
            << current_request_count << ".";
  }
  return reached_compile_threshold;
}

Status XlaCompilationCache::CompileImpl(
    const XlaCompiler::CompileOptions& compile_options,
    const XlaCompiler::Options& options, const NameAttrList& function,
    const std::vector<XlaCompiler::Argument>& args, OpKernelContext* ctx,
    CompileScope scope, CompileMode compile_mode,
    const XlaCompiler::CompilationResult** out_compilation_result,
    xla::LocalExecutable** out_executable) {
  DCHECK_NE(out_executable, nullptr);
  VLOG(2) << "XlaCompilationCache::Compile " << DebugString();

  if (VLOG_IS_ON(2)) {
    VLOG(2) << "num_inputs=" << args.size();
    for (int i = 0, end = args.size(); i < end; i++) {
      VLOG(3) << i << ": " << args[i].HumanString();
    }
  }
  TF_ASSIGN_OR_RETURN(Signature signature, BuildSignature(function, args));

  // The outer lock protects the existence of the cache entry. It does not
  // protect the contents of the cache entry.
  Entry* entry;
  {
    mutex_lock lock(compile_cache_mu_);
    // Find or create a cache entry.
    auto cache_entry = cache_.find(signature);
    if (cache_entry == cache_.end()) {
      auto inserted_entry =
          cache_.emplace(signature, std::make_unique<Entry>());
      cache_entry = inserted_entry.first;
    }
    entry = cache_entry->second.get();
  }

  // We always compile a cluster the very first time it is executed. This is
  // an optimistic guess that pays off for statically shaped TensorFlow graphs
  // (since they get the benefit of XLA right away without waiting for warmup)
  // and doesn't hurt much for dynamically shaped TensorFlow graphs (we "pay"
  // at most one cluster-compilation's worth of compile time).
  bool is_first_execution;

  {
    mutex_lock lock(cluster_compile_stats_mu_);
    auto it =
        cluster_compile_stats_.emplace(function.name(), ClusterCompileStats{})
            .first;
    is_first_execution = it->second.execution_count++ == 0;
  }

  string human_signature;
  if (VLOG_IS_ON(2)) {
    human_signature = VLOG_IS_ON(3) ? signature.HumanString() : function.name();
    VLOG(2) << "Signature: " << human_signature;
  }

  // Acquire the cache entry lock and compile, if necessary.
  // TODO(phawkins): this locking will need to be restructured when we
  // implement cache eviction.
  mutex_lock entry_lock(entry->mu);
  int64_t current_request_count = ++entry->request_count;
  VLOG(2) << "Compilation cache entry hit: "
          << static_cast<int>(entry->compile_state)
          << " signature: " << human_signature << " with request count "
          << current_request_count;

  CompileState state = entry->compile_state;
  *out_compilation_result = nullptr;
  *out_executable = nullptr;

  // Check if the requested entry is uncompiled and return an error if
  // compilation is disabled. This will raise an error for kLazy even if we
  // have not yet hit the compilation threshold and no compilation happens
  // this round. This avoids non-determinism in when compilation is
  // disallowed, for example due to changes to the threshold.
  if (state == CompileState::kUncompiled && FailOnXlaCompilation()) {
    VLOG(1) << "XLA compilation disabled: " << function.name() << "\n"
            << absl::StrJoin(
                   args, "\n",
                   [](std::string* out, const XlaCompiler::Argument& arg) {
                     absl::StrAppend(out, " arg: ", arg.HumanString());
                   });

    return errors::Internal("XLA compilation disabled");
  }

  if (state == CompileState::kUncompiled) {
    XLA_SCOPED_LOGGING_TIMER("Compilation of XLA executable");
    if (!ShouldCompileCluster(compile_mode, is_first_execution,
                              current_request_count, function)) {
      VLOG(2) << "Not compiling for signature: " << human_signature;
      return OkStatus();
    } else if (compile_mode == CompileMode::kAsync) {
      VLOG(2) << "Queueing asynchronous compilation for signature: "
              << human_signature;
      TF_RETURN_IF_ERROR(CompileAsynchronous(signature, entry, compile_options,
                                             options, args, function, ctx,
                                             scope));
      return OkStatus();
    } else {
      VLOG(2) << "Instantly compiling for signature: " << human_signature;
      TF_RETURN_IF_ERROR(CompileStrict(signature, entry, compile_options,
                                       options, args, function, ctx, scope));
    }
  } else if (state == CompileState::kCompiling) {
    VLOG(2) << "Ongoing asynchronous compilation for signature: "
            << human_signature;
    return OkStatus();
  } else if (state == CompileState::kCompiled) {
    VLOG(2) << "Already compiled for signature: " << human_signature;
  }

  TF_RETURN_IF_ERROR(entry->compilation_status);
  *out_compilation_result = &entry->compilation_result;
  *out_executable = entry->executable.get();
  return OkStatus();
}

XlaSerializedCacheKey XlaCompilationCache::BuildSerializedCacheKey(
    const Signature& sig, const xla::HloModuleProto& hlo_module) const {
  XlaSerializedCacheKey serialized_cache_key;
  serialized_cache_key.set_signature_fingerprint(Signature::Hash()(sig));
  serialized_cache_key.set_cluster_fingerprint(
      DeterministicProtoHash64(hlo_module));
  serialized_cache_key.set_device_type(device_type_.type_string());
  serialized_cache_key.set_prefix(persistance_prefix_);
  return serialized_cache_key;
}

Status XlaCompilationCache::VerifyLoadedCacheEntry(
    const XlaSerializedCacheKey& key, const xla::HloModuleProto& hlo_module,
    const XlaSerializedCacheEntry& entry) {
  XLA_SCOPED_LOGGING_TIMER(absl::StrCat("Verifying loaded cache entry: ",
                                        hlo_module.entry_computation_name()));

  if (!AreSerializedProtosEqual(key, entry.key())) {
    VLOG(2) << "Serialized cache key does not match:\n"
            << "got:\n"
            << entry.key().DebugString() << "\nexpected:\n"
            << key.DebugString() << "\n";
    return errors::InvalidArgument("Serialized cache key does not match.");
  }

  // Perform a stricter (slower) check to verify that the freshly built HLO
  // module matches the loaded one exactly.
  if (!disable_strict_signature_checks_) {
    if (!AreSerializedProtosEqual(hlo_module, entry.hlo_module())) {
      VLOG(2) << "HLOs do not match:\n"
              << "got:\n"
              << hlo_module.DebugString() << "\nexpected:\n"
              << entry.hlo_module().DebugString() << "\n";
      return errors::InvalidArgument("Serialized HLO does not match.");
    }
  }

  if (entry.executable().empty()) {
    return errors::InvalidArgument("No binary found in serialized entry.");
  }
  return OkStatus();
}

StatusOr<XlaSerializedCacheEntry> XlaCompilationCache::SerializeEntry(
    const XlaCompiler::Options& options, const Signature& sig,
    const Entry& entry) {
  if (entry.compile_state != CompileState::kCompiled) {
    return errors::FailedPrecondition(
        "Cache entry to serialize is not compiled.");
  }
  if (entry.executable == nullptr) {
    return errors::FailedPrecondition(
        "LocalExecutable not found for cache entry to serialize.");
  }
  if (entry.executable->executable() == nullptr) {
    return errors::FailedPrecondition(
        "Executable not found for cache entry to serialize.");
  }

  XlaSerializedCacheEntry serialized_entry;
  const xla::HloModuleProto& hlo_module =
      entry.compilation_result.computation->proto();
  *serialized_entry.mutable_key() = BuildSerializedCacheKey(sig, hlo_module);
  *serialized_entry.mutable_hlo_module() = hlo_module;

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<xla::AotCompilationResult> aot_result,
      BuildSerializedExecutable(options, entry.compilation_result));
  TF_ASSIGN_OR_RETURN(std::string serialized, aot_result->SerializeAsString());
  serialized_entry.set_executable(std::move(serialized));
  return serialized_entry;
}

namespace {

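// Illustrative example (hypothetical values): with persistent_cache_directory
// "/tmp/xla_cache" and a key rendering as "foo__123__456__XLA_GPU_JIT", the
// entry is stored at /tmp/xla_cache/foo__123__456__XLA_GPU_JIT.pb.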
std::string GetFilePath(const XlaSerializedCacheKey& key,
                        absl::string_view persistent_cache_directory) {
  const std::string file_name =
      absl::StrCat(XlaSerializedCacheKeyToString(key), ".pb");
  return io::JoinPath(persistent_cache_directory, file_name);
}

}  // namespace

Status XlaCompilationCache::SaveSerializedEntry(
    const XlaSerializedCacheEntry& entry) {
  Env* env = Env::Default();
  TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(persistent_cache_directory_));
  const std::string file_path =
      GetFilePath(entry.key(), persistent_cache_directory_);
  return WriteBinaryProto(env, file_path, entry);
}

StatusOr<std::optional<XlaSerializedCacheEntry>>
XlaCompilationCache::TryLoadSerializedEntry(const XlaSerializedCacheKey& key) {
  Env* env = Env::Default();
  const std::string file_path = GetFilePath(key, persistent_cache_directory_);
  if (!env->FileExists(file_path).ok()) {
    return StatusOr<std::optional<XlaSerializedCacheEntry>>(std::nullopt);
  }

  XlaSerializedCacheEntry entry;
  TF_RETURN_IF_ERROR(ReadTextOrBinaryProto(env, file_path, &entry));
  return StatusOr<std::optional<XlaSerializedCacheEntry>>(entry);
}

}  // namespace tensorflow