/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/kernels/xla_ops.h"

#include <map>
#include <memory>
#include <numeric>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/node_hash_map.h"
#include "absl/strings/str_cat.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/stream_executor_util.h"

// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that in the
// error case it returns RET instead of void.
#define OP_REQUIRES_OK_RETURN(CTX, RET, ...)                \
  do {                                                      \
    ::tensorflow::Status _s(__VA_ARGS__);                   \
    if (!TF_PREDICT_TRUE(_s.ok())) {                        \
      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
      return RET;                                           \
    }                                                       \
  } while (0)
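
// Example usage (mirrors the attribute-reading helpers further below, e.g.
// ConstantsVector):
//
//   std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
//     DataTypeVector constant_types;
//     OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
//                           ctx->GetAttr("Tconstants", &constant_types));
//     ...
//   }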

namespace tensorflow {

namespace {

auto* xla_launch_counter = monitoring::Counter<1>::New(
    "/tensorflow/core/xla_launch_counter",
    "The number of times an XlaLaunch op is called.", "device");

// A closure describing how to run a compiled version of a TensorFlow function.
//
// It may seem unusual to stick the resource variable snapshots in this class.
// This is necessary: we need to use the snapshots observed by the compiler as
// the initial values for the resource variables (and cannot snapshot them again
// during execution) because otherwise we risk observing a different snapshot
// with shapes different from what we compiled for.
class XlaExecutableClosure {
 public:
  explicit XlaExecutableClosure(
      xla::LocalClient* client, xla::LocalExecutable* executable,
      const XlaCompiler::CompilationResult* compilation_result,
      ResourceVarsSnapshot resource_var_snapshots, int num_constant_args)
      : client_(client),
        executable_(executable),
        compilation_result_(compilation_result),
        resource_var_snapshots_(std::move(resource_var_snapshots)),
        num_constant_args_(num_constant_args) {}

  XlaExecutableClosure(XlaExecutableClosure&&) = default;
  XlaExecutableClosure& operator=(XlaExecutableClosure&&) = default;

  xla::LocalClient* client() const { return client_; }
  xla::LocalExecutable* executable() const { return executable_; }
  const XlaCompiler::CompilationResult* compilation_result() const {
    return compilation_result_;
  }
  const ResourceVarsSnapshot& resource_var_snapshots() const {
    return resource_var_snapshots_;
  }
  int num_constant_args() const { return num_constant_args_; }

 private:
  xla::LocalClient* client_;
  xla::LocalExecutable* executable_;
  const XlaCompiler::CompilationResult* compilation_result_;
  ResourceVarsSnapshot resource_var_snapshots_;
  int num_constant_args_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosure);
};

// This maintains a mapping from a globally unique ID to XlaExecutableClosure
// instances.
class XlaExecutableClosureStore {
 public:
  XlaExecutableClosureStore() : key_counter_(0) {}

  using KeyT = string;

  KeyT Produce(XlaExecutableClosure result) {
    mutex_lock l(mutex_);
    KeyT key = absl::StrCat(key_counter_++);
    bool insert_successful = closures_.emplace(key, std::move(result)).second;
    DCHECK(insert_successful);
    (void)insert_successful;
    return key;
  }

  XlaExecutableClosure Consume(const KeyT& key) {
    mutex_lock l(mutex_);
    auto it = closures_.find(key);
    DCHECK(it != closures_.end());
    XlaExecutableClosure value = std::move(it->second);
    closures_.erase(it);
    return value;
  }

  static XlaExecutableClosureStore* Global() {
    static XlaExecutableClosureStore* instance = new XlaExecutableClosureStore;
    return instance;
  }

 private:
  mutex mutex_;
  int64_t key_counter_ TF_GUARDED_BY(mutex_);
  absl::flat_hash_map<KeyT, XlaExecutableClosure> closures_
      TF_GUARDED_BY(mutex_);

  TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore);
};
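
// A stored closure is handed from _XlaCompile to _XlaRun via the store's key
// (see XlaCompileOp::Compute and XlaRunOp::Compute below):
//
//   // _XlaCompile: stash the closure and emit its key as a string output.
//   XlaExecutableClosureStore::KeyT key =
//       XlaExecutableClosureStore::Global()->Produce(std::move(closure));
//
//   // _XlaRun: the key arrives as the op's last input and is consumed once.
//   XlaExecutableClosure closure =
//       XlaExecutableClosureStore::Global()->Consume(key);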

se::Stream* GetStream(OpKernelContext* ctx) {
  return ctx->op_device_context() ? ctx->op_device_context()->stream()
                                  : nullptr;
}

XlaComputationLaunchContext GetLaunchContext(
    const XlaPlatformInfo& platform_info, OpKernelContext* ctx,
    xla::LocalClient* client, se::DeviceMemoryAllocator* allocator) {
  se::Stream* stream = GetStream(ctx);
  int device_ordinal = stream ? stream->parent()->device_ordinal()
                              : client->default_device_ordinal();
  XlaComputationLaunchContext launch_context(
      client, allocator, device_ordinal,
      /*allocate_xla_tensors=*/platform_info.is_on_xla_device(),
      /*use_multiple_streams=*/platform_info.UseMultipleStreams());
  return launch_context;
}

StatusOr<xla::ExecutionOutput> RunExecutable(
    const XlaPlatformInfo& platform_info,
    const XlaComputationLaunchContext& launch_context,
    std::vector<xla::ExecutionInput> execution_inputs,
    xla::ExecutableRunOptions run_options, xla::LocalExecutable* executable,
    OpKernelContext* ctx, se::DeviceMemoryAllocator* allocator) {
  VLOG(2) << "Executing Xla Computation.";
  Env* env = Env::Default();
  auto start_time = env->NowMicros();

  se::Stream* stream = GetStream(ctx);
  run_options.set_stream(stream);
  run_options.set_allocator(allocator);
  run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
  run_options.set_rng_seed(GetXLARandomSeed());
  StatusOr<xla::ExecutionOutput> execution_output;
  bool run_synchronous =
      !stream || platform_info.platform_id() == se::host::kHostPlatformId;
  if (run_synchronous) {
    execution_output =
        executable->Run(std::move(execution_inputs), run_options);
  } else {
    execution_output =
        executable->RunAsync(std::move(execution_inputs), run_options);
  }

  auto elapsed = env->NowMicros() - start_time;
  VLOG(2) << "Elapsed time for Xla Executable Run: " << elapsed << "us";
  return execution_output;
}

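// Locks the resource variables at `variable_indices`, takes a snapshot of
// their values, and builds the XlaCompiler::Argument vector for `inputs`.
// The snapshot is returned alongside the arguments so that the values observed
// by the compiler can later be fed to the executable (see the comment on
// XlaExecutableClosure above).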
StatusOr<std::pair<std::vector<XlaCompiler::Argument>, ResourceVarsSnapshot>>
GetXlaCompilerArgsAndSnapshotVariables(
    absl::Span<const int> variable_indices,
    absl::Span<const int> must_be_constant_idxs,
    absl::Span<const Tensor* const> inputs, OpKernelContext* ctx) {
  std::pair<std::vector<XlaCompiler::Argument>, ResourceVarsSnapshot> result;

  std::vector<VariableInfo> variable_infos;
  TF_RETURN_IF_ERROR(
      GetVariableInfosFromInputs(ctx->resource_manager(), ctx->device(), inputs,
                                 variable_indices, &variable_infos));
  TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));

  TF_RETURN_IF_ERROR(SnapshotResourceVariables(ctx, variable_indices,
                                               variable_infos, &result.second));

  TF_ASSIGN_OR_RETURN(result.first,
                      XlaComputationLaunchContext::BuildXlaCompilerArguments(
                          must_be_constant_idxs, inputs, variable_infos,
                          static_cast<Device*>(ctx->device())));
  return result;
}

}  // namespace

XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
                                       const std::vector<int>& constants,
                                       const std::vector<int>& resources,
                                       const NameAttrList& function,
                                       bool has_ref_vars)
    : AsyncOpKernel(ctx),
      constants_(constants),
      resources_(resources),
      function_(function),
      platform_info_(XlaPlatformInfoFromDevice(ctx->device())),
      has_ref_vars_(has_ref_vars) {}

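// Looks up (or builds) the XlaCompilationCache in the ResourceMgr and asks it
// to compile `function` for `args`, returning the local client, the
// compilation result and, unless compilation was deferred, the executable.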
static Status CompileToLocalExecutable(
    OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars,
    const XlaPlatformInfo& platform_info,
    const std::vector<XlaCompiler::Argument>& args,
    XlaCompilationCache::CompileMode compile_mode,
    bool may_alias_resource_update, xla::LocalClient** client,
    const XlaCompiler::CompilationResult** compilation_result,
    xla::LocalExecutable** executable) {
  // We store information about the JIT-compiled XLA computation
  // in the ResourceMgr.
  ResourceMgr* rm = ctx->resource_manager();
  if (!rm) {
    return errors::Internal("No resource manager.");
  }

  XlaCompilationCache* cache;
  TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
      rm->default_container(), "xla_cache", &cache,
      [&](XlaCompilationCache** cache) {
        return BuildXlaCompilationCache(ctx->device(), ctx->function_library(),
                                        platform_info, cache);
      }));
  // Hold the reference to the JIT during evaluation. (We could probably
  // free it sooner because the ResourceMgr will retain a reference, but
  // this is more obviously correct.)
  core::ScopedUnref cache_ref(cache);

  *client = static_cast<xla::LocalClient*>(cache->client());

  XlaCompiler::Options options =
      GenerateCompilerOptions(*cache, *ctx->function_library(), ctx->device(),
                              GetStream(ctx), platform_info, has_ref_vars);

  XlaCompiler::CompileOptions compile_options;
  compile_options.is_entry_computation = true;
  // Optimization: where possible, have the computation return a naked array
  // rather than a one-element tuple.
  compile_options.always_return_tuple = false;
  compile_options.alias_resource_update =
      !has_ref_vars && may_alias_resource_update;

  return cache->Compile(options, function, args, compile_options, compile_mode,
                        compilation_result, executable);
}

// Get-or-create thread pool for a given collective.
static thread::ThreadPool* GetOrCreateThreadPoolForCollective(
    const XlaCompilationResult::CollectiveInfo& collective_info) {
  static absl::Mutex m(absl::kConstInit);
  static auto& thread_pool_cache ABSL_GUARDED_BY(m) =
      *new absl::node_hash_map<XlaCompilationResult::CollectiveInfo,
                               thread::ThreadPool>();
  absl::MutexLock l(&m);
  auto it = thread_pool_cache.find(collective_info);
  if (it == thread_pool_cache.end()) {
    // Create & cache thread pool.
    auto inserted_it = thread_pool_cache.emplace(
        std::piecewise_construct, std::forward_as_tuple(collective_info),
        std::forward_as_tuple(Env::Default(), "xla_collective_thread_pool",
                              collective_info.group_size));
    return &inserted_it.first->second;
  }
  return &it->second;
}

void XlaLocalLaunchBase::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
  VLOG(1) << "XlaLocalLaunchBase::ComputeAsync "
          << Canonicalize(function_.name(), AttrSlice(&function_.attr()));
  xla_launch_counter->GetCell(platform_info_.device_type().type_string())
      ->IncrementBy(1);

  std::vector<const Tensor*> inputs = InputsFromContext(ctx);
  xla::LocalClient* client;
  const XlaCompiler::CompilationResult* compilation_result;
  xla::LocalExecutable* executable;

  auto args_and_variables_snapshot = GetXlaCompilerArgsAndSnapshotVariables(
      resources_, constants_, inputs, ctx);
  OP_REQUIRES_OK_ASYNC(ctx, args_and_variables_snapshot.status(), done);
  const std::vector<XlaCompiler::Argument>& args =
      args_and_variables_snapshot->first;
  ResourceVarsSnapshot& variables_snapshot =
      args_and_variables_snapshot->second;

  const Status s = CompileToLocalExecutable(
      ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_, args,
      XlaCompilationCache::CompileMode::kStrict,
      /*may_alias_resource_update=*/true, &client, &compilation_result,
      &executable);
  OP_REQUIRES_OK_ASYNC(ctx, s, done);

  // Continuation of the execution, may be run in a different thread.
  auto run_xla_cluster = [ctx, variables_snapshot, client, executable,
                          compilation_result, done, inputs,
                          resources = resources_] {
    auto platform_info = XlaPlatformInfoFromDevice(ctx->device());
    std::map<int, const Tensor*> snapshot_ptrs;
    for (const auto& [variable_index, variable_tensor] : variables_snapshot) {
      snapshot_ptrs.emplace(variable_index, variable_tensor.has_value()
                                                ? &variable_tensor.value()
                                                : nullptr);
    }

    std::shared_ptr<se::DeviceMemoryAllocator> allocator =
        GetAllocator(ctx->device(), GetStream(ctx), platform_info);
    XlaComputationLaunchContext launch_context =
        GetLaunchContext(platform_info, ctx, client, allocator.get());

    const xla::HloInputOutputAliasConfig& input_output_alias =
        executable->executable()->module().input_output_alias_config();
    StatusOr<std::vector<xla::ExecutionInput>> execution_inputs =
        launch_context.PopulateInputs(ctx, compilation_result, snapshot_ptrs,
                                      /*missing_ctx_input_prefix=*/0,
                                      input_output_alias);
    OP_REQUIRES_OK_ASYNC(ctx, execution_inputs.status(), done);

    xla::gpu::GpuExecutableRunOptions gpu_options;
    xla::DeviceAssignment device_assignment;
    xla::ExecutableRunOptions run_options;
    if (compilation_result->collective_info.has_value()) {
      OP_REQUIRES_OK_ASYNC(
          ctx,
          ResolveDeviceAssignment(ctx, *compilation_result->collective_info,
                                  run_options, device_assignment, gpu_options),
          done);
    }

    // Hardcode run id to always be zero: TF distributed strategy
    // differentiates between subsequent runs using dependency edges.
    // This is safe, as only TF dist-strat can produce distributed ops, and we
    // can rely on TF dist-strat invariants.
    xla::RunId run_id(0);
    run_options.set_run_id(run_id);

    StatusOr<xla::ExecutionOutput> execution_output = RunExecutable(
        platform_info, launch_context, std::move(*execution_inputs),
        run_options, executable, ctx, allocator.get());
    OP_REQUIRES_ASYNC(ctx, execution_output.ok(), execution_output.status(),
                      done);

    std::vector<VariableInfo> variable_infos;
    OP_REQUIRES_OK_ASYNC(
        ctx,
        GetVariableInfosFromInputs(ctx->resource_manager(), ctx->device(),
                                   inputs, resources, &variable_infos),
        done);
    OP_REQUIRES_OK_ASYNC(ctx, LockVariables(absl::MakeSpan(variable_infos)),
                         done);

    OP_REQUIRES_OK_ASYNC(
        ctx,
        launch_context.PopulateOutputs(
            ctx, compilation_result, execution_output->ConsumeResult(),
            /*missing_ctx_input_prefix=*/0, absl::MakeSpan(variable_infos),
            input_output_alias, snapshot_ptrs),
        done);
    VLOG(1) << "Done";
    done();
  };

  // If we are using collectives, we need to run in a separate threadpool.
  if (compilation_result->collective_info.has_value()) {
    GetOrCreateThreadPoolForCollective(*compilation_result->collective_info)
        ->Schedule(run_xla_cluster);
  } else {
    // Otherwise, just run normally: we merely "pretend" to be asynchronous.
    run_xla_cluster();
  }
}

namespace {
// Helper functions to construct parameters for the XlaLocalLaunchBase
// constructor from OpKernelConstruction.
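//
// The launch op's inputs are ordered as [constants..., args..., resources...]:
// ConstantsVector returns the indices [0, num_constants) and ResourcesVector
// returns the trailing num_resources indices starting at
// num_constants + num_args.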
std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
  DataTypeVector constant_types;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Tconstants", &constant_types));
  std::vector<int> constants(constant_types.size());
  std::iota(constants.begin(), constants.end(), 0);
  return constants;
}

std::vector<int> ResourcesVector(OpKernelConstruction* ctx) {
  DataTypeVector constant_types;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Tconstants", &constant_types));

  DataTypeVector arg_types;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Targs", &arg_types));

  int num_resources;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Nresources", &num_resources));

  std::vector<int> resources(num_resources);
  std::iota(resources.begin(), resources.end(),
            constant_types.size() + arg_types.size());
  return resources;
}

NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
  const NameAttrList* func;
  OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func));
  return *func;
}

bool MustCompileAttr(OpKernelConstruction* ctx) {
  bool must_compile;
  OP_REQUIRES_OK_RETURN(ctx, false,
                        ctx->GetAttr("must_compile", &must_compile));
  return must_compile;
}

bool HasRefVars(OpKernelConstruction* ctx) {
  bool has_ref_vars;
  OP_REQUIRES_OK_RETURN(ctx, false,
                        ctx->GetAttr(kXlaHasReferenceVarsAttr, &has_ref_vars));
  return has_ref_vars;
}

}  // namespace

XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
    : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
                         FunctionAttr(ctx), /*has_ref_vars=*/true) {}

XlaLocalLaunchOp::~XlaLocalLaunchOp() {
  VLOG(1) << "XlaLocalLaunchOp destroyed";
}

XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
    : OpKernel(ctx),
      constants_(ConstantsVector(ctx)),
      resources_(ResourcesVector(ctx)),
      function_(FunctionAttr(ctx)),
      platform_info_(XlaPlatformInfoFromDevice(ctx->device())),
      must_compile_(MustCompileAttr(ctx)),
      has_ref_vars_(HasRefVars(ctx)) {}

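// Compiles the cluster (or defers compilation / falls back to the TF function)
// and emits two scalar outputs: a DT_STRING key identifying the stored
// XlaExecutableClosure and a DT_BOOL flag indicating whether a compiled
// executable is available for the paired _XlaRun op.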
void XlaCompileOp::Compute(OpKernelContext* ctx) {
  VLOG(3) << "XlaCompileOp " << def().name()
          << (must_compile_ ? " (must-compile)" : "");
  xla::LocalClient* client;
  const XlaCompiler::CompilationResult* kernel;
  xla::LocalExecutable* executable;
  ResourceVarsSnapshot variables_snapshot;

  std::vector<const Tensor*> inputs = InputsFromContext(ctx);
  bool cannot_compile_cluster;
  {
    mutex_lock guard(cannot_compile_cluster_mu_);
    cannot_compile_cluster = cannot_compile_cluster_;
  }
  XlaCompilationCache::CompileMode compile_mode = [&] {
    if (must_compile_) {
      return XlaCompilationCache::CompileMode::kStrict;
    }
    return GetXlaOpsCommonFlags().tf_xla_async_compilation
               ? XlaCompilationCache::CompileMode::kAsync
               : XlaCompilationCache::CompileMode::kLazy;
  }();

  if (GetXlaOpsCommonFlags().tf_xla_always_defer_compilation ||
      cannot_compile_cluster) {
    executable = nullptr;
  } else {
    auto args_and_variables_snapshot = GetXlaCompilerArgsAndSnapshotVariables(
        resources_, constants_, inputs, ctx);
    OP_REQUIRES_OK(ctx, args_and_variables_snapshot.status());
    const std::vector<XlaCompiler::Argument>& args =
        args_and_variables_snapshot->first;
    variables_snapshot = std::move(args_and_variables_snapshot->second);

    // Do not alias resource updates as locking variables in XlaCompile and
    // unlocking them in XlaRun may lead to deadlocks.
    const Status status = CompileToLocalExecutable(
        ctx, function_, has_ref_vars_, platform_info_, args, compile_mode,
        /*may_alias_resource_update=*/false, &client, &kernel, &executable);
    if (compile_mode != XlaCompilationCache::CompileMode::kLazy ||
        status.code() != error::UNIMPLEMENTED) {
      OP_REQUIRES_OK(ctx, status);
    }

    if (status.code() == error::UNIMPLEMENTED) {
      LOG(WARNING) << "Compilation failed: " << status.ToString()
                   << ". Falling back to TF function call.";

      BroadcastOptimizationRemark(
          XlaOptimizationRemark::UNIMPLEMENTED_OPERATION, status.ToString())
          .IgnoreError();
      executable = nullptr;
      mutex_lock guard(cannot_compile_cluster_mu_);
      cannot_compile_cluster_ = true;
    }
  }

  AllocatorAttributes host_alloc_attrs;
  host_alloc_attrs.set_gpu_compatible(true);
  host_alloc_attrs.set_on_host(true);
  Allocator* cpu_allocator = ctx->device()->GetAllocator(host_alloc_attrs);

  // Async compilation returns nullptr executable without an error.
  if (!executable) {
    DCHECK(!must_compile_);
    Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));

    Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({}));
    compilation_successful.scalar<bool>()() = false;
    ctx->set_output(0, Tensor(cpu_allocator, DT_STRING, TensorShape({})));
    ctx->set_output(1, compilation_successful);
    return;
  }

  // Each execution of an XlaCompile op creates a new XlaExecutableClosure, even
  // if it didn't have to compile the cluster because of a compilation-cache
  // hit. This is because we at least need new snapshots of the resource
  // variables.
  XlaExecutableClosureStore::KeyT key =
      XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure(
          client, executable, kernel, std::move(variables_snapshot),
          constants_.size()));

  Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));
  compilation_key.flat<tstring>()(0) = key;

  Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({}));
  compilation_successful.flat<bool>()(0) = true;

  ctx->set_output(0, compilation_key);
  ctx->set_output(1, compilation_successful);
}

XlaRunOp::XlaRunOp(OpKernelConstruction* ctx)
    : OpKernel(ctx), platform_info_(XlaPlatformInfoFromDevice(ctx->device())) {}

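// Consumes the closure produced by the paired _XlaCompile op (its key is the
// op's last input), runs the compiled executable, and populates the op's
// outputs and updated resource variables.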
void XlaRunOp::Compute(OpKernelContext* ctx) {
  VLOG(3) << "XlaRunOp " << def().name();
  Tensor key_tensor = ctx->input(ctx->num_inputs() - 1);
  const XlaExecutableClosureStore::KeyT& key = key_tensor.flat<tstring>()(0);

  XlaExecutableClosure closure =
      XlaExecutableClosureStore::Global()->Consume(key);
  std::shared_ptr<se::DeviceMemoryAllocator> allocator =
      GetAllocator(ctx->device(), GetStream(ctx), platform_info_);
  XlaComputationLaunchContext launch_context =
      GetLaunchContext(platform_info_, ctx, closure.client(), allocator.get());

  // We're missing the must-be-constant inputs, tell `PopulateInputs`
  // about this. We don't actually need these inputs because they've
  // already been baked into the compiled kernel.
  const xla::HloInputOutputAliasConfig& input_output_alias =
      closure.executable()->executable()->module().input_output_alias_config();
  StatusOr<std::vector<xla::ExecutionInput>> execution_inputs;
  std::map<int, const Tensor*> snapshot_ptrs;
  {
    tensorflow::profiler::TraceMe hlo_module_activity(
        [&] {
          return absl::StrCat(
              "Populate Inputs (",
              closure.compilation_result()->xla_input_shapes.size(), ")");
        },
        tensorflow::profiler::TraceMeLevel::kInfo);

    for (const auto& [variable_index, variable_tensor] :
         closure.resource_var_snapshots()) {
      snapshot_ptrs.emplace(variable_index, variable_tensor.has_value()
                                                ? &variable_tensor.value()
                                                : nullptr);
    }
    execution_inputs = launch_context.PopulateInputs(
        ctx, closure.compilation_result(), snapshot_ptrs,
        /*missing_ctx_input_prefix=*/closure.num_constant_args(),
        input_output_alias);
    OP_REQUIRES_OK(ctx, execution_inputs.status());
  }

  xla::ExecutableRunOptions run_options;
  StatusOr<xla::ExecutionOutput> execution_output = RunExecutable(
      platform_info_, launch_context, std::move(*execution_inputs), run_options,
      closure.executable(), ctx, allocator.get());
  OP_REQUIRES(ctx, execution_output.ok(), execution_output.status());

  tensorflow::profiler::TraceMe hlo_module_activity(
      [&] {
        return absl::StrCat("Populate Outputs (", ctx->num_outputs(), ")");
      },
      tensorflow::profiler::TraceMeLevel::kInfo);

  StatusOr<std::vector<VariableInfo>> variable_infos = GatherVariableInfo(
      ctx, *closure.compilation_result(), closure.num_constant_args());
  OP_REQUIRES_OK(ctx, variable_infos.status());
  OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(*variable_infos)));
  OP_REQUIRES_OK(
      ctx,
      launch_context.PopulateOutputs(
          ctx, closure.compilation_result(), execution_output->ConsumeResult(),
          /*missing_ctx_input_prefix=*/closure.num_constant_args(),
          absl::MakeSpan(*variable_infos), input_output_alias, snapshot_ptrs));
}

XlaMergeOp::XlaMergeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

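// Forwards input 0 if it is present, otherwise input 1. In graphs rewritten by
// the auto-clustering passes these inputs typically correspond to the compiled
// (_XlaRun) result and the fallback TF function-call result, of which only one
// is produced at runtime.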
void XlaMergeOp::Compute(OpKernelContext* ctx) {
  VLOG(3) << "XlaMergeOp " << def().name();
  int i = 0;
  if (ctx->has_input(i) || ctx->has_input(++i)) {
    ctx->set_output(0, ctx->input(i));
  }
}

REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);

REGISTER_KERNEL_BUILDER(Name("XlaLaunch")
                            .Device(DEVICE_GPU)
                            .HostMemory("constants")
                            .HostMemory("resources"),
                        XlaLocalLaunchOp);

REGISTER_KERNEL_BUILDER(Name("_XlaCompile").Device(DEVICE_CPU), XlaCompileOp);
REGISTER_KERNEL_BUILDER(Name("_XlaCompile")
                            .Device(DEVICE_GPU)
                            .HostMemory("constants")
                            .HostMemory("key")
                            .HostMemory("compilation_successful")
                            .HostMemory("resources"),
                        XlaCompileOp);

REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_CPU), XlaRunOp);
REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_GPU).HostMemory("key"),
                        XlaRunOp);

REGISTER_KERNEL_BUILDER(Name("_XlaMerge").Device(DEVICE_CPU), XlaMergeOp);
REGISTER_KERNEL_BUILDER(Name("_XlaMerge").Device(DEVICE_GPU), XlaMergeOp);

}  // namespace tensorflow
674