/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_compilation_cache.h"

#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <variant>

#include "tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h"
#include "absl/base/call_once.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/variant.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/jit/xla_compilation_cache.pb.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/protobuf/debug_event.pb.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {
namespace {

using TensorTypeAndShape = XlaCompilationCache::Signature::TensorTypeAndShape;

constexpr char kXlaSerializedCacheKeySeparator[] = "__";

// Functor that converts a Signature's arg to a human-readable string.
struct SignatureHumanStringAppender {
  explicit SignatureHumanStringAppender(string* dest) : dest(dest) {}
  string* dest;
  void operator()(const Tensor& arg) {
    absl::StrAppend(dest, "; ", arg.DebugString());
  }
  void operator()(const TensorTypeAndShape& arg) {
    absl::StrAppend(dest, ",", DataTypeString(arg.first));
    absl::StrAppend(dest, " [", absl::StrJoin(arg.second, ","), "]");
  }
};

// Functor that compares the arg values of two different signatures. Returns
// true when the args are not equal.
struct SignatureNotEqual {
  bool operator()(const Tensor& arg, const Tensor& other) {
    return arg.dtype() != other.dtype() || arg.shape() != other.shape() ||
           arg.tensor_data() != other.tensor_data();
  }
  bool operator()(const TensorTypeAndShape& arg,
                  const TensorTypeAndShape& other) {
    return arg.first != other.first || arg.second != other.second;
  }
  bool operator()(const Tensor& arg, const TensorTypeAndShape& other) {
    return true;
  }
  bool operator()(const TensorTypeAndShape& arg, const Tensor& other) {
    return true;
  }
};

// Functor that incrementally computes a Signature's hash given its current
// hash and one of its args.
struct SignatureHashCombiner {
  explicit SignatureHashCombiner(const uint64 h) : h(h) {}
  uint64 h;
  uint64 operator()(const Tensor& arg) {
    h = Hash64Combine(h, std::hash<int>()(static_cast<int>(arg.dtype())));
    h = Hash64Combine(
        h, Hash64(arg.tensor_data().data(), arg.tensor_data().size()));
    for (int dim = 0; dim < arg.dims(); ++dim) {
      h = Hash64Combine(h, std::hash<int>()(arg.dim_size(dim)));
    }
    return h;
  }
  uint64 operator()(const TensorTypeAndShape& arg) {
    h = Hash64Combine(h, std::hash<int>()(static_cast<int>(arg.first)));
    h = Hash64Combine(h, std::hash<int>()(arg.second.size()));
    for (int dim : arg.second) {
      h = Hash64Combine(h, std::hash<int>()(dim));
    }
    return h;
  }
};

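// Flattens `key` into a single string suitable for use as a file name. For
// example (hypothetical values): a key with prefix "model", signature
// fingerprint 123, cluster fingerprint 456, and device type "XLA_GPU_JIT"
// becomes "model__123__456__XLA_GPU_JIT"; an empty prefix drops the leading
// separator.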
std::string XlaSerializedCacheKeyToString(const XlaSerializedCacheKey& key) {
  return absl::StrCat(
      key.prefix(), key.prefix().empty() ? "" : kXlaSerializedCacheKeySeparator,
      key.signature_fingerprint(), kXlaSerializedCacheKeySeparator,
      key.cluster_fingerprint(), kXlaSerializedCacheKeySeparator,
      key.device_type());
}

}  // namespace

constexpr int64_t XlaCompilationCache::kDefaultCompilationThreshold;
constexpr int64_t
    XlaCompilationCache::AsyncCompilationState::kNumCompilerThreads;
constexpr int64_t
    XlaCompilationCache::AsyncCompilationState::kMaxNumOngoingCompilations;

XlaCompilationCache::XlaCompilationCache(Config config,
                                         xla::LocalClient* client,
                                         DeviceType device_type)
    : client_(client),
      device_type_(std::move(device_type)),
      disable_strict_signature_checks_(config.disable_strict_signature_checks),
      persistance_prefix_(config.persistance_prefix),
      persistent_cache_directory_(config.persistent_cache_directory) {}

XlaCompilationCache::~XlaCompilationCache() {
  // Ensure that any use of our programs has completed by waiting for all
  // stream executors to complete.
  for (auto* executor : client_->backend().stream_executors()) {
    bool ok = executor->SynchronizeAllActivity();
    if (!ok) {
      LOG(ERROR) << "Error synchronizing activity while waiting for all "
                    "programs to complete";
    }
  }
  // Wait for all outstanding compilations to finish.
  // Reset the pointer explicitly in the top-level destructor. Without this,
  // the pointer would only be reset when the AsyncCompilationState is
  // destroyed, which depends on the order of the members in the
  // XlaCompilationCache class and is therefore error-prone if that order
  // changes.
  async_compilation_state_.compiler_threads.reset();
  // TODO(b/110813685): Think about the program ownership model. Programs are
  // currently owned by the compilation cache which means we must wait for
  // program completion in the destructor. There are multiple compilation
  // caches around, which complicates things a little. Perhaps having programs
  // be shared_ptrs (an invasive change) would make the model easier to reason
  // about?
}

string XlaCompilationCache::DebugString() const {
  return "XLA JIT compilation cache";
}

// Returns a human-readable string for a signature: the canonicalized function
// name followed by each argument's type and shape (or value, for compile-time
// constants).
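// Example output (illustrative values only):
//   "cluster_0,DT_FLOAT [10,100],DT_INT32 [5]"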
string XlaCompilationCache::Signature::HumanString() const {
  string result = name;
  for (const auto& a : args) {
    absl::visit(SignatureHumanStringAppender(&result), a);
  }
  return result;
}

bool XlaCompilationCache::Signature::operator==(const Signature& other) const {
  if (name != other.name) return false;
  if (args.size() != other.args.size()) return false;
  for (int i = 0, end = args.size(); i < end; ++i) {
    if (absl::visit(SignatureNotEqual(), args[i], other.args[i])) {
      return false;
    }
  }
  return true;
}

uint64 XlaCompilationCache::Signature::Hash::operator()(
    const XlaCompilationCache::Signature& signature) const {
  uint64 h = std::hash<string>()(signature.name);
  for (const auto& arg : signature.args) {
    h = absl::visit(SignatureHashCombiner(h), arg);
  }
  return h;
}

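// Builds the cache-lookup Signature for `function` called with `args`:
// compile-time constants contribute their full values, while parameters and
// resources contribute only their type and shape.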
StatusOr<XlaCompilationCache::Signature> XlaCompilationCache::BuildSignature(
    const NameAttrList& function,
    absl::Span<const XlaCompiler::Argument> args) {
  Signature signature;
  signature.name = Canonicalize(function.name(), AttrSlice(&function.attr()));

  for (const XlaCompiler::Argument& arg : args) {
    switch (arg.kind) {
      case XlaCompiler::Argument::kConstant:
      case XlaCompiler::Argument::kConstantResource:
        signature.args.push_back(arg.constant_value);
        break;
      case XlaCompiler::Argument::kParameter:
      case XlaCompiler::Argument::kResource:
        signature.args.push_back(
            TensorTypeAndShape(arg.type, arg.DimensionSizesAsInlinedVector()));
        break;
      default:
        return errors::InvalidArgument(
            "Unhandled argument kind in XlaCompilationCache: ",
            arg.HumanString());
    }
  }
  return std::move(signature);
}

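// Returns a vector of pointers into `shapes`, in the form expected by
// xla::LocalClient::Compile. The pointers are only valid while `shapes`
// outlives them.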
static std::vector<const xla::Shape*> GetShapePointers(
    absl::Span<const xla::Shape> shapes) {
  std::vector<const xla::Shape*> shape_ptrs;
  shape_ptrs.reserve(shapes.size());
  for (const auto& shape : shapes) {
    shape_ptrs.push_back(&shape);
  }
  return shape_ptrs;
}

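// Assembles the xla::ExecutableBuildOptions for a compilation: device
// ordinal, result layout, allocator, replica count for collectives, and
// debug/determinism flags derived from `options` and `result`.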
static xla::ExecutableBuildOptions GetBuildOptions(
    const XlaCompiler::Options& options,
    const XlaCompiler::CompilationResult& result, int default_device_ordinal) {
  xla::ExecutableBuildOptions build_options;
  if (result.collective_info) {
    build_options.set_num_replicas(result.collective_info->group_size);
  }
  build_options.set_device_ordinal(options.device_ordinal != -1
                                       ? options.device_ordinal
                                       : default_device_ordinal);
  build_options.set_result_layout(result.xla_output_shape);
  build_options.set_device_allocator(options.device_allocator.get());
  build_options.set_alias_passthrough_params(options.alias_passthrough_params);
  build_options.mutable_debug_options()->set_xla_detailed_logging_and_dumping(
      options.detailed_logging);
  if (tensorflow::OpDeterminismRequired()) {
    build_options.mutable_debug_options()->set_xla_gpu_deterministic_ops(true);
  }
  return build_options;
}

Status XlaCompilationCache::BuildExecutable(
    const XlaCompiler::Options& options,
    const XlaCompiler::CompilationResult& result,
    std::unique_ptr<xla::LocalExecutable>* executable) {
  VLOG(2) << "Compiling to local executable";

  std::vector<const xla::Shape*> argument_layouts =
      GetShapePointers(result.xla_input_shapes);
  xla::ExecutableBuildOptions build_options =
      GetBuildOptions(options, result, client_->default_device_ordinal());
  TF_ASSIGN_OR_RETURN(
      auto executables,
      client_->Compile(*result.computation, argument_layouts, build_options));
  TF_RET_CHECK(executables.size() == 1);
  *executable = std::move(executables[0]);
  return OkStatus();
}

StatusOr<std::unique_ptr<xla::AotCompilationResult>>
XlaCompilationCache::BuildSerializedExecutable(
    const XlaCompiler::Options& options,
    const XlaCompiler::CompilationResult& result) {
  VLOG(2) << "Compiling to serializable ahead-of-time result";

  std::vector<const xla::Shape*> argument_layouts =
      GetShapePointers(result.xla_input_shapes);
  xla::ExecutableBuildOptions build_options =
      GetBuildOptions(options, result, client_->default_device_ordinal());
  TF_ASSIGN_OR_RETURN(
      std::vector<std::unique_ptr<xla::AotCompilationResult>> aot_results,
      client_->CompileAheadOfTime(*result.computation, argument_layouts,
                                  build_options));
  TF_RET_CHECK(aot_results.size() == 1);
  return std::move(aot_results[0]);
}

StatusOr<std::unique_ptr<xla::LocalExecutable>>
XlaCompilationCache::LoadExecutable(
    const XlaCompiler::Options& options,
    const XlaCompiler::CompilationResult& result,
    const std::string& serialized_aot_result) {
  VLOG(2) << "Loading local executable from serialized AOT result.";

  xla::ExecutableBuildOptions build_options =
      GetBuildOptions(options, result, client_->default_device_ordinal());
  return client_->Load(serialized_aot_result, build_options);
}

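// Usage sketch for the public entry point below (hypothetical caller; the
// variable names are illustrative only):
//
//   const XlaCompiler::CompilationResult* result;
//   xla::LocalExecutable* executable;
//   TF_RETURN_IF_ERROR(cache->Compile(options, function, args,
//                                     compile_options, CompileMode::kLazy,
//                                     &result, &executable));
//   if (executable == nullptr) {
//     // Not compiled (below the lazy threshold, or still compiling
//     // asynchronously): fall back to the non-XLA execution path.
//   }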
Status XlaCompilationCache::Compile(
    const XlaCompiler::Options& options, const NameAttrList& function,
    const std::vector<XlaCompiler::Argument>& args,
    const XlaCompiler::CompileOptions& compile_options,
    CompileMode compile_mode,
    const XlaCompiler::CompilationResult** out_compilation_result,
    xla::LocalExecutable** out_executable) {
  return CompileImpl(compile_options, options, function, args,
                     /*ctx=*/nullptr, CompileScope::kFunction, compile_mode,
                     out_compilation_result, out_executable);
}

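// Worked example (illustrative numbers): a cluster compiled 11 times but
// executed only 200 times is megamorphic, since 11 > 10 and
// 200 < 50 * 11 = 550; at 600 executions it would no longer be.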
static bool ShouldBeMegamorphic(int64_t compile_count,
                                int64_t execution_count) {
  const int64_t kCompileThreshold = 10;
  const int64_t kMinExecutionsPerCompile = 50;

  // This heuristic is trying to capture the following property: have we sunk
  // a certain minimum amount of compile time into the cluster that didn't
  // quite "pay off"?
  return compile_count > kCompileThreshold &&
         execution_count < kMinExecutionsPerCompile * compile_count;
}

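// Creates a dummy single-node Graph of the form
//   _SOURCE -ctrl-> _arg0.._argN -> main_node -> _retval0.._retvalM
// so that the regular graph-compilation path can be reused for a single op.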
StatusOr<std::unique_ptr<Graph>> CreateGraph(
    const NodeDef& node_def, absl::Span<const XlaCompiler::Argument> args,
    absl::Span<const DataType> result_types) {
  // TODO(b/74182462): We implement this by creating a new dummy Graph
  // including _Arg nodes, and letting CompileGraph walk it. This could be
  // optimized.
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  // First create the actual node we care about computing.
  TF_ASSIGN_OR_RETURN(Node * main_node, graph->AddNode(node_def));

  // Create dummy _Arg nodes. Link these to `main_node` and also via a control
  // dependency edge to the _SOURCE node.
  for (int64_t i = 0, end = args.size(); i < end; ++i) {
    Node* node;
    string arg_name = absl::StrCat("_arg", i);
    Status status =
        NodeBuilder(arg_name, FunctionLibraryDefinition::kArgOp)
            .ControlInput(graph->source_node())
            .Attr("T", args[i].kind == XlaCompiler::Argument::kResource
                           ? DT_RESOURCE
                           : args[i].type)
            .Attr("index", i)
            .Finalize(graph.get(), &node);
    TF_RETURN_IF_ERROR(status);
    graph->AddEdge(node, 0, main_node, i);
  }

  // Similarly with return values: create dummy _Retval nodes fed by
  // `main_node`.
  for (int64_t i = 0, end = result_types.size(); i < end; ++i) {
    Node* node;
    string retval_name = absl::StrCat("_retval", i);
    Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp)
                        .Input(main_node, i)
                        .Attr("T", result_types[i])
                        .Attr("index", i)
                        .Finalize(graph.get(), &node);
    TF_RETURN_IF_ERROR(status);
  }
  FixupSourceAndSinkEdges(graph.get());
  return graph;
}

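// Lowers a single op to HLO. Prefers the MLIR bridge when it is explicitly
// enabled (or being rolled out for TPU), and falls back to the old bridge
// when the MLIR bridge is disabled, fails, or does not apply.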
Status XlaSingleOpToHlo(
    XlaCompiler* compiler, const XlaCompiler::Options& options,
    const std::vector<XlaCompiler::Argument>& args,
    const XlaCompiler::SingleOpCompileArgument& single_op_compile_argument,
    const XlaCompiler::CompileOptions& compile_options,
    XlaCompiler::CompilationResult* compilation_result) {
  const std::vector<DataType>& result_dtypes =
      single_op_compile_argument.output_dtypes;
  const NodeDef& node_def = single_op_compile_argument.node_def;
  TF_ASSIGN_OR_RETURN(
      auto graph,
      CreateGraph(node_def, args, single_op_compile_argument.output_dtypes));

  auto compile_with_old_bridge = [&]() {
    *compilation_result = {};
    return compiler->CompileGraph(compile_options, node_def.name(),
                                  std::move(graph), args, compilation_result);
  };

  const ConfigProto* config = &(single_op_compile_argument.config_proto);
  auto bridge_rollout = GetMlirBridgeRolloutState(
      config ? std::optional<ConfigProto>(*config) : std::nullopt);
  if (bridge_rollout ==
          ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED ||
      node_def.op() == "VarIsInitializedOp" ||
      (bridge_rollout !=
           ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED &&
       options.device_type.type_string() != DEVICE_TPU_XLA_JIT)) {
    return compile_with_old_bridge();
  }

  GraphDebugInfo debug_info;
  std::vector<std::string> control_rets;
  if (result_dtypes.empty()) {
    control_rets.push_back(node_def.name());
  }

  bool mlir_enabled = (bridge_rollout ==
                       ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED);
  VLOG(1) << "Attempting MLIR bridge."
          << (mlir_enabled ? " MLIR is explicitly enabled." : "");
  auto mlir_result = CompileGraphToXlaHlo(
      *graph, mlir::SpanToArrayRef<XlaCompiler::Argument>(args), control_rets,
      options.device_type.type_string(), compile_options.use_tuple_arg,
      /*analyse_graph=*/!mlir_enabled, *options.flib_def, debug_info,
      options.shape_determination_fns, compilation_result);

  if (mlir_result.ok() || mlir_enabled) {
    return mlir_result;
  }

  VLOG(2) << "Failed second phase of the MLIR bridge. Will "
             "retry with the old bridge. MLIR bridge compilation status: "
          << mlir_result;
  return compile_with_old_bridge();
}

Status XlaCompilationCache::CompileSingleOp(
    const XlaCompiler::Options& options,
    const std::vector<XlaCompiler::Argument>& args, OpKernelContext* ctx,
    const XlaCompiler::CompileOptions& compile_options,
    const XlaCompiler::CompilationResult** out_compilation_result,
    xla::LocalExecutable** out_executable) {
  const NodeDef& def = ctx->op_kernel().def();
  NameAttrList name;
  name.set_name(def.op());
  *name.mutable_attr() = def.attr();
  // Remove the "_class" attribute from the attribute set used to create the
  // compilation cache key. This attribute is information for the colocator
  // and causes false uniqueness between nodes.
  name.mutable_attr()->erase("_class");
  return CompileImpl(compile_options, options, name, args, ctx,
                     CompileScope::kOp, CompileMode::kStrict,
                     out_compilation_result, out_executable);
}

namespace {
// Print something that users can search for to definitively ascertain that
// XLA was used for their TF model.
//
// Prints only once to avoid spamming LOG(INFO).
void LogOnceXlaCompiledFirstCluster() {
  static absl::once_flag log_once;
  absl::call_once(log_once, [] {
    LOG(INFO) << "Compiled cluster using XLA! This line is logged at most "
                 "once for the lifetime of the process.";
  });
}
}  // namespace

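// Synchronously compiles `function` (or, for CompileScope::kOp, the single op
// held by `ctx`) and populates `entry`. If a persistent cache directory is
// configured, first tries to load a previously serialized executable for the
// same signature and HLO, and otherwise compiles and saves a new entry. Also
// updates per-cluster compile statistics.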
Status XlaCompilationCache::CompileStrict(
    const Signature& sig, Entry* entry,
    const XlaCompiler::CompileOptions& compile_options,
    const XlaCompiler::Options& options,
    const std::vector<XlaCompiler::Argument>& args,
    const NameAttrList& function, OpKernelContext* ctx, CompileScope scope) {
  tensorflow::Env* env = tensorflow::Env::Default();
  const uint64 compile_start_us = env->NowMicros();

  XlaCompiler compiler(options);
  entry->compile_state = CompileState::kCompiled;
  entry->compilation_status = [&] {
    if (scope == CompileScope::kOp) {
      XlaCompiler::SingleOpCompileArgument single_op_arg;
      std::vector<DataType> output_dtypes(ctx->num_outputs());
      for (int i = 0; i < output_dtypes.size(); ++i) {
        output_dtypes[i] = ctx->expected_output_dtype(i);
      }
      single_op_arg.output_dtypes = std::move(output_dtypes);
      single_op_arg.node_def = ctx->op_kernel().def();
      auto* config_proto = ctx->function_library()->config_proto();
      if (config_proto != nullptr) {
        single_op_arg.config_proto = *config_proto;
      }
      return XlaSingleOpToHlo(&compiler, options, args, single_op_arg,
                              compile_options, &entry->compilation_result);

    } else {
      CHECK(scope == CompileScope::kFunction);  // Crash OK
      return compiler.CompileFunction(compile_options, function, args,
                                      &entry->compilation_result);
    }
  }();
  TF_RETURN_IF_ERROR(entry->compilation_status);
  TF_RET_CHECK(entry->executable.get() == nullptr);
  TF_RET_CHECK(entry->compilation_result.computation != nullptr);

  std::optional<XlaSerializedCacheEntry> serialized_entry;
  if (!persistent_cache_directory_.empty()) {
    const xla::HloModuleProto& hlo_module =
        entry->compilation_result.computation->proto();

    XlaSerializedCacheKey cache_key = BuildSerializedCacheKey(sig, hlo_module);

    {
      XLA_SCOPED_LOGGING_TIMER(absl::StrCat(
          "Try loading serialized cache entry:", sig.HumanString()));
      TF_ASSIGN_OR_RETURN(serialized_entry, TryLoadSerializedEntry(cache_key));
    }

    if (serialized_entry.has_value()) {
      TF_RETURN_IF_ERROR(
          VerifyLoadedCacheEntry(cache_key, hlo_module, *serialized_entry));
    }
  }

  if (serialized_entry.has_value()) {
    VLOG(1) << "Loading cached entry for: " << sig.HumanString();
    StatusOr<std::unique_ptr<xla::LocalExecutable>> executable = LoadExecutable(
        options, entry->compilation_result, serialized_entry->executable());
    entry->compilation_status = executable.status();
    if (executable.ok()) {
      entry->executable = *std::move(executable);
    }
  } else {
    entry->compilation_status =
        BuildExecutable(options, entry->compilation_result, &entry->executable);

    // Caching is attempted regardless of entry->compilation_status, so a
    // cache flush is required to take advantage of newer compilation code.
    if (!persistent_cache_directory_.empty()) {
      XLA_SCOPED_LOGGING_TIMER(absl::StrCat(
          "Serializing and saving cache entry: ", sig.HumanString()));
      TF_ASSIGN_OR_RETURN(XlaSerializedCacheEntry serialized_entry,
                          SerializeEntry(options, sig, *entry));
      TF_RETURN_IF_ERROR(SaveSerializedEntry(std::move(serialized_entry)));
    }
  }

  const uint64 compile_end_us = env->NowMicros();
  const uint64 compile_time_us = compile_end_us - compile_start_us;
  metrics::UpdateXlaCompilationTime(compile_time_us);

  mutex_lock lock(cluster_compile_stats_mu_);
  const std::string& function_name = function.name();
  auto it = cluster_compile_stats_.find(function_name);
  // Use a double so sub-second compile times are not truncated to zero.
  const double compile_time_s = compile_time_us / 1.0e6;
  it->second.compile_count++;
  it->second.cumulative_compile_time_us += compile_time_us;
  LogOnceXlaCompiledFirstCluster();
  VLOG(1) << "compiled " << function_name << " " << it->second.compile_count
          << " times, compile time: " << compile_time_us
          << " us, cumulative: " << it->second.cumulative_compile_time_us
          << " us ("
          << tensorflow::strings::HumanReadableElapsedTime(compile_time_s)
          << " / "
          << tensorflow::strings::HumanReadableElapsedTime(
                 it->second.cumulative_compile_time_us / 1.0e6)
          << ")";

  XlaJitCompilationActivity jit_compilation_activity;
  jit_compilation_activity.set_cluster_name(function_name);
  jit_compilation_activity.set_compile_count(it->second.compile_count);
  jit_compilation_activity.set_compile_time_us(compile_time_us);
  jit_compilation_activity.set_cumulative_compile_time_us(
      it->second.cumulative_compile_time_us);
  jit_compilation_activity.set_used_persistent_cache(
      serialized_entry.has_value());
  TF_RETURN_IF_ERROR(BroadcastXlaActivity(std::move(jit_compilation_activity)));

  return OkStatus();
}

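// Schedules compilation of `function` on the compiler thread pool and returns
// immediately. The entry is marked kCompiling; the worker compiles into a
// local Entry and then copies the result into `entry` under its mutex.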
Status XlaCompilationCache::CompileAsynchronous(
    const Signature& signature, Entry* entry,
    const XlaCompiler::CompileOptions& compile_options,
    const XlaCompiler::Options& options,
    const std::vector<XlaCompiler::Argument>& args,
    const NameAttrList& function, OpKernelContext* ctx, CompileScope scope) {
  // Explicitly capture all required data by value for async compilation.
  entry->compile_state = CompileState::kCompiling;
  {
    mutex_lock lock(async_compilation_state_.async_compilation_state_mu);
    async_compilation_state_.num_ongoing_compilations++;
  }
  // Don't move the above code into the thread function, as it synchronously
  // updates the async compilation state!

  // When the ThreadPool for the compilation cache is destroyed, it waits for
  // compilations to have finished. This means that both 'entry' and 'this'
  // will be alive for the duration of the compilation.
  // !!Pay attention when additional variables must be captured by this
  // lambda!! All values are captured by value. Make sure that all pointer
  // values (like entry) do not get freed until the lambda has finished.
  const std::string& function_name = function.name();
  async_compilation_state_.compiler_threads->Schedule([=] {
    Entry local_entry;
    VLOG(2) << "Starting asynchronous compilation of cluster " << function_name
            << '.';
    // We don't need to lock local_entry.mu, but do it anyway to satisfy
    // thread safety analysis.
    mutex_lock entry_lock(local_entry.mu);
    Status s = CompileStrict(signature, &local_entry, compile_options, options,
                             args, function, ctx, scope);
    VLOG(2) << "Finished asynchronous compilation of cluster " << function_name
            << '.';
    {
      mutex_lock lock(async_compilation_state_.async_compilation_state_mu);
      async_compilation_state_.num_ongoing_compilations--;
    }
    {  // Populate the original entry with the compilation result.
      mutex_lock entry_lock(entry->mu);
      if (!s.ok()) {
        entry->compilation_status = s;
      } else {
        entry->compilation_status = local_entry.compilation_status;
      }
      entry->compilation_result = local_entry.compilation_result;
      entry->compile_state = local_entry.compile_state;
      entry->executable = std::move(local_entry.executable);
    }
  });
  return OkStatus();
}

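// Decides whether a cluster should be compiled on this request. kStrict
// always compiles; the first execution always compiles; kLazy compiles once
// the request count reaches kDefaultCompilationThreshold; kAsync compiles
// right away unless too many compilations are already in flight.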
bool XlaCompilationCache::ShouldCompileCluster(CompileMode compile_mode,
                                               bool is_first_execution,
                                               int64_t current_request_count,
                                               const NameAttrList& function) {
  std::optional<int64_t> compile_threshold;
  if (compile_mode == CompileMode::kLazy) {
    compile_threshold = kDefaultCompilationThreshold;
  } else if (compile_mode == CompileMode::kAsync) {
    compile_threshold = 0;  // for now, always compile right away.
  }

  if (compile_mode == CompileMode::kStrict) {
    // Lazy compilation is disabled.
    return true;
  }

  if (is_first_execution) {
    return true;
  }

  if (compile_mode == CompileMode::kAsync) {
    // Asynchronous compilation is enabled.
    mutex_lock lock(async_compilation_state_.async_compilation_state_mu);
    if (async_compilation_state_.num_ongoing_compilations >=
        async_compilation_state_.kMaxNumOngoingCompilations) {
      VLOG(2) << "Not asynchronously compiling cluster " << function.name()
              << " because of too many ongoing compilations.";
      return false;
    }
  }

  bool reached_compile_threshold = current_request_count >= *compile_threshold;
  if (!reached_compile_threshold) {
    VLOG(2) << "Not compiling cluster " << function.name()
            << " because it has not reached the compile threshold; threshold "
            << "is " << *compile_threshold << ", execution count "
            << current_request_count << ".";
  }
  return reached_compile_threshold;
}

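// Common implementation behind Compile and CompileSingleOp: builds the
// signature, finds or creates the cache entry under compile_cache_mu_, and
// then, holding only the entry's own mutex, compiles strictly,
// asynchronously, or not at all depending on `compile_mode` and the cluster's
// execution history.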
Status XlaCompilationCache::CompileImpl(
    const XlaCompiler::CompileOptions& compile_options,
    const XlaCompiler::Options& options, const NameAttrList& function,
    const std::vector<XlaCompiler::Argument>& args, OpKernelContext* ctx,
    CompileScope scope, CompileMode compile_mode,
    const XlaCompiler::CompilationResult** out_compilation_result,
    xla::LocalExecutable** out_executable) {
  DCHECK_NE(out_executable, nullptr);
  VLOG(2) << "XlaCompilationCache::Compile " << DebugString();

  if (VLOG_IS_ON(2)) {
    VLOG(2) << "num_inputs=" << args.size();
    for (int i = 0, end = args.size(); i < end; i++) {
      VLOG(3) << i << ": " << args[i].HumanString();
    }
  }
  TF_ASSIGN_OR_RETURN(Signature signature, BuildSignature(function, args));

  // The outer lock protects the existence of the cache entry. It does not
  // protect the contents of the cache entry.
  Entry* entry;
  {
    mutex_lock lock(compile_cache_mu_);
    // Find or create a cache entry.
    auto cache_entry = cache_.find(signature);
    if (cache_entry == cache_.end()) {
      auto inserted_entry =
          cache_.emplace(signature, std::make_unique<Entry>());
      cache_entry = inserted_entry.first;
    }
    entry = cache_entry->second.get();
  }

  // We always compile a cluster the very first time it is executed. This is
  // an optimistic guess that pays off for statically shaped TensorFlow graphs
  // (since they get the benefit of XLA right away without waiting for warmup)
  // and doesn't hurt much for dynamically shaped TensorFlow graphs (we "pay"
  // at most one cluster-compilation's worth of compile time).
  bool is_first_execution;

  {
    mutex_lock lock(cluster_compile_stats_mu_);
    auto it =
        cluster_compile_stats_.emplace(function.name(), ClusterCompileStats{})
            .first;
    is_first_execution = it->second.execution_count++ == 0;
  }

  string human_signature;
  if (VLOG_IS_ON(2)) {
    human_signature = VLOG_IS_ON(3) ? signature.HumanString() : function.name();
    VLOG(2) << "Signature: " << human_signature;
  }

  // Acquire the cache entry lock and compile, if necessary.
  // TODO(phawkins): this locking will need to be restructured when we
  // implement cache eviction.
  mutex_lock entry_lock(entry->mu);
  int64_t current_request_count = ++entry->request_count;
  VLOG(2) << "Compilation cache entry hit: "
          << static_cast<int>(entry->compile_state)
          << " signature: " << human_signature << " with request count "
          << current_request_count;

  CompileState state = entry->compile_state;
  *out_compilation_result = nullptr;
  *out_executable = nullptr;

  // Check if the requested entry is uncompiled and return an error if
  // compilation is disabled. This will raise an error for kLazy even if we
  // have not yet hit the compilation threshold and no compilation happens
  // this round. This is to avoid non-determinism of when compilation is
  // disallowed, for example by changing the threshold.
  if (state == CompileState::kUncompiled && FailOnXlaCompilation()) {
    VLOG(1) << "XLA compilation disabled: " << function.name() << "\n"
            << absl::StrJoin(
                   args, "\n",
                   [](std::string* out, const XlaCompiler::Argument& arg) {
                     absl::StrAppend(out, " arg: ", arg.HumanString());
                   });

    return errors::Internal("XLA compilation disabled");
  }

  if (state == CompileState::kUncompiled) {
    XLA_SCOPED_LOGGING_TIMER("Compilation of XLA executable");
    if (!ShouldCompileCluster(compile_mode, is_first_execution,
                              current_request_count, function)) {
      VLOG(2) << "Not compiling for signature: " << human_signature;
      return OkStatus();
    } else if (compile_mode == CompileMode::kAsync) {
      VLOG(2) << "Queueing asynchronous compilation for signature: "
              << human_signature;
      TF_RETURN_IF_ERROR(CompileAsynchronous(signature, entry, compile_options,
                                             options, args, function, ctx,
                                             scope));
      return OkStatus();
    } else {
      VLOG(2) << "Instantly compiling for signature: " << human_signature;
      TF_RETURN_IF_ERROR(CompileStrict(signature, entry, compile_options,
                                       options, args, function, ctx, scope));
    }
  } else if (state == CompileState::kCompiling) {
    VLOG(2) << "Ongoing asynchronous compilation for signature: "
            << human_signature;
    return OkStatus();
  } else if (state == CompileState::kCompiled) {
    VLOG(2) << "Already compiled for signature: " << human_signature;
  }

  TF_RETURN_IF_ERROR(entry->compilation_status);
  *out_compilation_result = &entry->compilation_result;
  *out_executable = entry->executable.get();
  return OkStatus();
}

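// Builds the key under which a serialized entry is stored: the signature
// hash, a deterministic fingerprint of the HLO module, the device type, and
// the configured persistence prefix.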
XlaSerializedCacheKey XlaCompilationCache::BuildSerializedCacheKey(
    const Signature& sig, const xla::HloModuleProto& hlo_module) const {
  XlaSerializedCacheKey serialized_cache_key;
  serialized_cache_key.set_signature_fingerprint(Signature::Hash()(sig));
  serialized_cache_key.set_cluster_fingerprint(
      DeterministicProtoHash64(hlo_module));
  serialized_cache_key.set_device_type(device_type_.type_string());
  serialized_cache_key.set_prefix(persistance_prefix_);
  return serialized_cache_key;
}

Status XlaCompilationCache::VerifyLoadedCacheEntry(
    const XlaSerializedCacheKey& key, const xla::HloModuleProto& hlo_module,
    const XlaSerializedCacheEntry& entry) {
  XLA_SCOPED_LOGGING_TIMER(absl::StrCat("Verifying loaded cache entry: ",
                                        hlo_module.entry_computation_name()));

  if (!AreSerializedProtosEqual(key, entry.key())) {
    VLOG(2) << "Serialized cache key does not match:\n"
            << "got:\n"
            << entry.key().DebugString() << "\nexpected:\n"
            << key.DebugString() << "\n";
    return errors::InvalidArgument("Serialized cache key does not match.");
  }

  // Perform a stricter (slower) check of the snapshot to verify that the HLO
  // modules match exactly.
  if (!disable_strict_signature_checks_) {
    if (!AreSerializedProtosEqual(hlo_module, entry.hlo_module())) {
      VLOG(2) << "HLOs do not match:\n"
              << "got:\n"
              << hlo_module.DebugString() << "\nexpected:\n"
              << entry.hlo_module().DebugString() << "\n";
      return errors::InvalidArgument("Serialized HLO does not match.");
    }
  }

  if (entry.executable().empty()) {
    return errors::InvalidArgument("No binary found in serialized entry.");
  }
  return OkStatus();
}

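// Serializes a compiled cache entry for persistence: ahead-of-time compiles
// the computation to an AOT result and packages it with its key and HLO
// module. Requires that `entry` was successfully compiled.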
StatusOr<XlaSerializedCacheEntry> XlaCompilationCache::SerializeEntry(
    const XlaCompiler::Options& options, const Signature& sig,
    const Entry& entry) {
  if (entry.compile_state != CompileState::kCompiled) {
    return errors::FailedPrecondition(
        "Cache entry to serialize is not compiled.");
  }
  if (entry.executable == nullptr) {
    return errors::FailedPrecondition(
        "LocalExecutable not found for cache entry to serialize.");
  }
  if (entry.executable->executable() == nullptr) {
    return errors::FailedPrecondition(
        "Executable not found for cache entry to serialize.");
  }

  XlaSerializedCacheEntry serialized_entry;
  const xla::HloModuleProto& hlo_module =
      entry.compilation_result.computation->proto();
  *serialized_entry.mutable_key() = BuildSerializedCacheKey(sig, hlo_module);
  *serialized_entry.mutable_hlo_module() = hlo_module;

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<xla::AotCompilationResult> aot_result,
      BuildSerializedExecutable(options, entry.compilation_result));
  TF_ASSIGN_OR_RETURN(std::string serialized, aot_result->SerializeAsString());
  serialized_entry.set_executable(std::move(serialized));
  return serialized_entry;
}

namespace {

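// Maps a cache key to the path of its file in the persistent cache directory.
// For example (hypothetical values), the key string
// "model__123__456__XLA_GPU_JIT" under "/tmp/xla_cache" yields
// "/tmp/xla_cache/model__123__456__XLA_GPU_JIT.pb".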
std::string GetFilePath(const XlaSerializedCacheKey& key,
                        absl::string_view persistent_cache_directory) {
  const std::string file_name =
      absl::StrCat(XlaSerializedCacheKeyToString(key), ".pb");
  return io::JoinPath(persistent_cache_directory, file_name);
}

}  // namespace

Status XlaCompilationCache::SaveSerializedEntry(
    const XlaSerializedCacheEntry& entry) {
  Env* env = Env::Default();
  TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(persistent_cache_directory_));
  const std::string file_path =
      GetFilePath(entry.key(), persistent_cache_directory_);
  return WriteBinaryProto(env, file_path, entry);
}

StatusOr<std::optional<XlaSerializedCacheEntry>>
XlaCompilationCache::TryLoadSerializedEntry(const XlaSerializedCacheKey& key) {
  Env* env = Env::Default();
  const std::string file_path = GetFilePath(key, persistent_cache_directory_);
  if (!env->FileExists(file_path).ok()) {
    return StatusOr<std::optional<XlaSerializedCacheEntry>>(std::nullopt);
  }

  XlaSerializedCacheEntry entry;
  TF_RETURN_IF_ERROR(ReadTextOrBinaryProto(env, file_path, &entry));
  return StatusOr<std::optional<XlaSerializedCacheEntry>>(entry);
}

}  // namespace tensorflow