/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/kernels/xla_ops.h"

#include <map>
#include <memory>
#include <tuple>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/stream_executor_util.h"

// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that, in the
// error case, it returns RET instead of void.
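//
// Usage sketch (the attribute name "N" is made up for illustration; the real
// call sites are the attribute-reading helpers near the end of this file):
//
//   std::vector<int> SomeAttrVector(OpKernelConstruction* ctx) {
//     int n;
//     OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(), ctx->GetAttr("N", &n));
//     return std::vector<int>(n);
//   }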
#define OP_REQUIRES_OK_RETURN(CTX, RET, ...)                \
  do {                                                      \
    ::tensorflow::Status _s(__VA_ARGS__);                   \
    if (!TF_PREDICT_TRUE(_s.ok())) {                        \
      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
      return RET;                                           \
    }                                                       \
  } while (0)

namespace tensorflow {

namespace {

auto* xla_launch_counter = monitoring::Counter<1>::New(
    "/tensorflow/core/xla_launch_counter",
    "The number of times an XlaLaunch is called.", "device");

// A closure describing how to run a compiled version of a TensorFlow function.
//
// It may seem unusual to stick the resource variable snapshots in this class.
// This is necessary: we need to use the snapshots observed by the compiler as
// the initial values for the resource variables (and cannot snapshot them again
// during execution) because otherwise we risk observing a different snapshot
// with shapes different from what we compiled for.
class XlaExecutableClosure {
 public:
  explicit XlaExecutableClosure(
      xla::LocalClient* client, xla::LocalExecutable* executable,
      const XlaCompiler::CompilationResult* compilation_result,
      ResourceVarsSnapshot resource_var_snapshots, int num_constant_args)
      : client_(client),
        executable_(executable),
        compilation_result_(compilation_result),
        resource_var_snapshots_(std::move(resource_var_snapshots)),
        num_constant_args_(num_constant_args) {}

  XlaExecutableClosure(XlaExecutableClosure&&) = default;
  XlaExecutableClosure& operator=(XlaExecutableClosure&&) = default;

  xla::LocalClient* client() const { return client_; }
  xla::LocalExecutable* executable() const { return executable_; }
  const XlaCompiler::CompilationResult* compilation_result() const {
    return compilation_result_;
  }
  const ResourceVarsSnapshot& resource_var_snapshots() const {
    return resource_var_snapshots_;
  }
  int num_constant_args() const { return num_constant_args_; }

 private:
  xla::LocalClient* client_;
  xla::LocalExecutable* executable_;
  const XlaCompiler::CompilationResult* compilation_result_;
  ResourceVarsSnapshot resource_var_snapshots_;
  int num_constant_args_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosure);
};

// This maintains a mapping from a globally unique ID to XlaExecutableClosure
// instances.
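//
// Rough sketch of the handshake implemented by the ops later in this file
// (not additional API; Produce is called from XlaCompileOp::Compute and
// Consume from XlaRunOp::Compute, with the key passed between them as a
// DT_STRING scalar tensor):
//
//   // _XlaCompile:
//   auto key = XlaExecutableClosureStore::Global()->Produce(
//       XlaExecutableClosure(client, executable, kernel,
//                            std::move(variables_snapshot),
//                            constants_.size()));
//
//   // _XlaRun:
//   XlaExecutableClosure closure =
//       XlaExecutableClosureStore::Global()->Consume(key);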
class XlaExecutableClosureStore {
 public:
  XlaExecutableClosureStore() : key_counter_(0) {}

  using KeyT = string;

  KeyT Produce(XlaExecutableClosure result) {
    mutex_lock l(mutex_);
    KeyT key = absl::StrCat(key_counter_++);
    bool insert_successful = closures_.emplace(key, std::move(result)).second;
    DCHECK(insert_successful);
    (void)insert_successful;
    return key;
  }

  XlaExecutableClosure Consume(const KeyT& key) {
    mutex_lock l(mutex_);
    auto it = closures_.find(key);
    DCHECK(it != closures_.end());
    XlaExecutableClosure value = std::move(it->second);
    closures_.erase(it);
    return value;
  }

  static XlaExecutableClosureStore* Global() {
    static XlaExecutableClosureStore* instance = new XlaExecutableClosureStore;
    return instance;
  }

 private:
  mutex mutex_;
  int64_t key_counter_ TF_GUARDED_BY(mutex_);
  absl::flat_hash_map<KeyT, XlaExecutableClosure> closures_
      TF_GUARDED_BY(mutex_);

  TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore);
};

se::Stream* GetStream(OpKernelContext* ctx) {
  return ctx->op_device_context() ? ctx->op_device_context()->stream()
                                  : nullptr;
}

XlaComputationLaunchContext GetLaunchContext(
    const XlaPlatformInfo& platform_info, OpKernelContext* ctx,
    xla::LocalClient* client, se::DeviceMemoryAllocator* allocator) {
  se::Stream* stream = GetStream(ctx);
  int device_ordinal = stream ? stream->parent()->device_ordinal()
                              : client->default_device_ordinal();
  XlaComputationLaunchContext launch_context(
      client, allocator, device_ordinal,
      /*allocate_xla_tensors=*/platform_info.is_on_xla_device(),
      /*use_multiple_streams=*/platform_info.UseMultipleStreams());
  return launch_context;
}

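// Runs `executable` with the given inputs. The run is synchronous when there
// is no stream or when running on the host platform, and asynchronous
// (RunAsync) otherwise.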
StatusOr<xla::ExecutionOutput> RunExecutable(
    const XlaPlatformInfo& platform_info,
    const XlaComputationLaunchContext& launch_context,
    std::vector<xla::ExecutionInput> execution_inputs,
    xla::ExecutableRunOptions run_options, xla::LocalExecutable* executable,
    OpKernelContext* ctx, se::DeviceMemoryAllocator* allocator) {
  VLOG(2) << "Executing Xla Computation.";
  Env* env = Env::Default();
  auto start_time = env->NowMicros();

  se::Stream* stream = GetStream(ctx);
  run_options.set_stream(GetStream(ctx));
  run_options.set_allocator(allocator);
  run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
  run_options.set_rng_seed(GetXLARandomSeed());
  StatusOr<xla::ExecutionOutput> execution_output;
  bool run_synchronous =
      !stream || platform_info.platform_id() == se::host::kHostPlatformId;
  if (run_synchronous) {
    execution_output =
        executable->Run(std::move(execution_inputs), run_options);
  } else {
    execution_output =
        executable->RunAsync(std::move(execution_inputs), run_options);
  }

  auto elapsed = env->NowMicros() - start_time;
  VLOG(2) << "Elapsed time for Xla Executable Run: " << elapsed << "us";
  return execution_output;
}

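// Builds the XlaCompiler::Argument list for the given inputs and takes a
// snapshot of the resource variables named by `variable_indices`, holding the
// variable locks while both are produced so the two stay consistent.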
StatusOr<std::pair<std::vector<XlaCompiler::Argument>, ResourceVarsSnapshot>>
GetXlaCompilerArgsAndSnapshotVariables(
    absl::Span<const int> variable_indices,
    absl::Span<const int> must_be_constant_idxs,
    absl::Span<const Tensor* const> inputs, OpKernelContext* ctx) {
  std::pair<std::vector<XlaCompiler::Argument>, ResourceVarsSnapshot> result;

  std::vector<VariableInfo> variable_infos;
  TF_RETURN_IF_ERROR(
      GetVariableInfosFromInputs(ctx->resource_manager(), ctx->device(), inputs,
                                 variable_indices, &variable_infos));
  TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));

  TF_RETURN_IF_ERROR(SnapshotResourceVariables(ctx, variable_indices,
                                               variable_infos, &result.second));

  TF_ASSIGN_OR_RETURN(result.first,
                      XlaComputationLaunchContext::BuildXlaCompilerArguments(
                          must_be_constant_idxs, inputs, variable_infos,
                          static_cast<Device*>(ctx->device())));
  return result;
}

}  // namespace

XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
                                       const std::vector<int>& constants,
                                       const std::vector<int>& resources,
                                       const NameAttrList& function,
                                       bool has_ref_vars)
    : AsyncOpKernel(ctx),
      constants_(constants),
      resources_(resources),
      function_(function),
      platform_info_(XlaPlatformInfoFromDevice(ctx->device())),
      has_ref_vars_(has_ref_vars) {}

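// Compiles `function` via the per-device XlaCompilationCache stored in the
// ResourceMgr, returning the client, compilation result and executable.
// Depending on `compile_mode`, compilation may be deferred (kLazy) or done in
// the background (kAsync), in which case the caller can be handed back a
// nullptr executable; XlaCompileOp::Compute below handles that case.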
static Status CompileToLocalExecutable(
    OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars,
    const XlaPlatformInfo& platform_info,
    const std::vector<XlaCompiler::Argument>& args,
    XlaCompilationCache::CompileMode compile_mode,
    bool may_alias_resource_update, xla::LocalClient** client,
    const XlaCompiler::CompilationResult** compilation_result,
    xla::LocalExecutable** executable) {
  // We store information about the JIT-compiled XLA computation
  // in the ResourceMgr.
  ResourceMgr* rm = ctx->resource_manager();
  if (!rm) {
    return errors::Internal("No resource manager.");
  }

  XlaCompilationCache* cache;
  TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
      rm->default_container(), "xla_cache", &cache,
      [&](XlaCompilationCache** cache) {
        return BuildXlaCompilationCache(ctx->device(), ctx->function_library(),
                                        platform_info, cache);
      }));
  // Hold the reference to the JIT during evaluation. (We could probably
  // free it sooner because the ResourceMgr will retain a reference, but
  // this is more obviously correct.)
  core::ScopedUnref cache_ref(cache);

  *client = static_cast<xla::LocalClient*>(cache->client());

  XlaCompiler::Options options =
      GenerateCompilerOptions(*cache, *ctx->function_library(), ctx->device(),
                              GetStream(ctx), platform_info, has_ref_vars);

  XlaCompiler::CompileOptions compile_options;
  compile_options.is_entry_computation = true;
  // Optimization: where possible, have the computation return a naked array
  // rather than a one-element tuple.
  compile_options.always_return_tuple = false;
  compile_options.alias_resource_update =
      !has_ref_vars && may_alias_resource_update;

  return cache->Compile(options, function, args, compile_options, compile_mode,
                        compilation_result, executable);
}

// Get-or-create thread pool for a given collective.
static thread::ThreadPool* GetOrCreateThreadPoolForCollective(
    const XlaCompilationResult::CollectiveInfo& collective_info) {
  static absl::Mutex m(absl::kConstInit);
  static auto& thread_pool_cache ABSL_GUARDED_BY(m) =
      *new absl::node_hash_map<XlaCompilationResult::CollectiveInfo,
                               thread::ThreadPool>();
  absl::MutexLock l(&m);
  auto it = thread_pool_cache.find(collective_info);
  if (it == thread_pool_cache.end()) {
    // Create & cache thread pool.
    auto inserted_it = thread_pool_cache.emplace(
        std::piecewise_construct, std::forward_as_tuple(collective_info),
        std::forward_as_tuple(Env::Default(), "xla_collective_thread_pool",
                              collective_info.group_size));
    return &inserted_it.first->second;
  }
  return &it->second;
}
302 
ComputeAsync(OpKernelContext * ctx,DoneCallback done)303 void XlaLocalLaunchBase::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
  VLOG(1) << "XlaLocalLaunchBase::ComputeAsync "
          << Canonicalize(function_.name(), AttrSlice(&function_.attr()));
  xla_launch_counter->GetCell(platform_info_.device_type().type_string())
      ->IncrementBy(1);

  std::vector<const Tensor*> inputs = InputsFromContext(ctx);
  xla::LocalClient* client;
  const XlaCompiler::CompilationResult* compilation_result;
  xla::LocalExecutable* executable;

  auto args_and_variables_snapshot = GetXlaCompilerArgsAndSnapshotVariables(
      resources_, constants_, inputs, ctx);
  OP_REQUIRES_OK_ASYNC(ctx, args_and_variables_snapshot.status(), done);
  const std::vector<XlaCompiler::Argument>& args =
      args_and_variables_snapshot->first;
  ResourceVarsSnapshot& variables_snapshot =
      args_and_variables_snapshot->second;

  const Status s = CompileToLocalExecutable(
      ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_, args,
      XlaCompilationCache::CompileMode::kStrict,
      /*may_alias_resource_update=*/true, &client, &compilation_result,
      &executable);
  OP_REQUIRES_OK_ASYNC(ctx, s, done);

  // Continuation of the execution, may be run in a different thread.
  auto run_xla_cluster = [ctx, variables_snapshot, client, executable,
                          compilation_result, done, inputs,
                          resources = resources_] {
    auto platform_info = XlaPlatformInfoFromDevice(ctx->device());
    std::map<int, const Tensor*> snapshot_ptrs;
    for (const auto& [variable_index, variable_tensor] : variables_snapshot) {
      snapshot_ptrs.emplace(variable_index, variable_tensor.has_value()
                                                ? &variable_tensor.value()
                                                : nullptr);
    }

    std::shared_ptr<se::DeviceMemoryAllocator> allocator =
        GetAllocator(ctx->device(), GetStream(ctx), platform_info);
    XlaComputationLaunchContext launch_context =
        GetLaunchContext(platform_info, ctx, client, allocator.get());

    const xla::HloInputOutputAliasConfig& input_output_alias =
        executable->executable()->module().input_output_alias_config();
    StatusOr<std::vector<xla::ExecutionInput>> execution_inputs =
        launch_context.PopulateInputs(ctx, compilation_result, snapshot_ptrs,
                                      /*missing_ctx_input_prefix=*/0,
                                      input_output_alias);
    OP_REQUIRES_OK_ASYNC(ctx, execution_inputs.status(), done);

    xla::gpu::GpuExecutableRunOptions gpu_options;
    xla::DeviceAssignment device_assignment;
    xla::ExecutableRunOptions run_options;
    if (compilation_result->collective_info.has_value()) {
      OP_REQUIRES_OK_ASYNC(
          ctx,
          ResolveDeviceAssignment(ctx, *compilation_result->collective_info,
                                  run_options, device_assignment, gpu_options),
          done);
    }

    // Hardcode run id to always be zero: TF distributed strategy differentiates
    // between subsequent runs using dependency edges.
    // This is safe, as only TF dist-strat can produce distributed ops, and we
    // can rely on TF dist-strat invariants.
    xla::RunId run_id(0);
    run_options.set_run_id(run_id);

    StatusOr<xla::ExecutionOutput> execution_output = RunExecutable(
        platform_info, launch_context, std::move(*execution_inputs),
        run_options, executable, ctx, allocator.get());
    OP_REQUIRES_ASYNC(ctx, execution_output.ok(), execution_output.status(),
                      done);

    std::vector<VariableInfo> variable_infos;
    OP_REQUIRES_OK_ASYNC(
        ctx,
        GetVariableInfosFromInputs(ctx->resource_manager(), ctx->device(),
                                   inputs, resources, &variable_infos),
        done);
    OP_REQUIRES_OK_ASYNC(ctx, LockVariables(absl::MakeSpan(variable_infos)),
                         done);

    OP_REQUIRES_OK_ASYNC(
        ctx,
        launch_context.PopulateOutputs(
            ctx, compilation_result, execution_output->ConsumeResult(),
            /*missing_ctx_input_prefix=*/0, absl::MakeSpan(variable_infos),
            input_output_alias, snapshot_ptrs),
        done);
    VLOG(1) << "Done";
    done();
  };

  // If we are using collectives, we need to run in a separate threadpool.
  if (compilation_result->collective_info.has_value()) {
    GetOrCreateThreadPoolForCollective(*compilation_result->collective_info)
        ->Schedule(run_xla_cluster);
  } else {
    // Otherwise, just run normally: we merely "pretend" to be asynchronous.
    run_xla_cluster();
  }
}

namespace {
// Helper functions to construct parameters for the
// XlaLocalLaunchBase constructor from OpKernelConstruction.
std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
  DataTypeVector constant_types;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Tconstants", &constant_types));
  std::vector<int> constants(constant_types.size());
  std::iota(constants.begin(), constants.end(), 0);
  return constants;
}

std::vector<int> ResourcesVector(OpKernelConstruction* ctx) {
  DataTypeVector constant_types;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Tconstants", &constant_types));

  DataTypeVector arg_types;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Targs", &arg_types));

  int num_resources;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Nresources", &num_resources));

  std::vector<int> resources(num_resources);
  std::iota(resources.begin(), resources.end(),
            constant_types.size() + arg_types.size());
  return resources;
}

NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
  const NameAttrList* func;
  OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func));
  return *func;
}

bool MustCompileAttr(OpKernelConstruction* ctx) {
  bool must_compile;
  OP_REQUIRES_OK_RETURN(ctx, false,
                        ctx->GetAttr("must_compile", &must_compile));
  return must_compile;
}

bool HasRefVars(OpKernelConstruction* ctx) {
  bool has_ref_vars;
  OP_REQUIRES_OK_RETURN(ctx, false,
                        ctx->GetAttr(kXlaHasReferenceVarsAttr, &has_ref_vars));
  return has_ref_vars;
}

}  // namespace

XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
    : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
                         FunctionAttr(ctx), /*has_ref_vars=*/true) {}

XlaLocalLaunchOp::~XlaLocalLaunchOp() {
  VLOG(1) << "XlaLocalLaunchOp destroyed";
}

XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
    : OpKernel(ctx),
      constants_(ConstantsVector(ctx)),
      resources_(ResourcesVector(ctx)),
      function_(FunctionAttr(ctx)),
      platform_info_(XlaPlatformInfoFromDevice(ctx->device())),
      must_compile_(MustCompileAttr(ctx)),
      has_ref_vars_(HasRefVars(ctx)) {}

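// Produces two scalar outputs: output 0 is a DT_STRING key identifying the
// XlaExecutableClosure to be consumed by _XlaRun, and output 1 is a DT_BOOL
// flag saying whether a compiled executable is available. When compilation is
// deferred, still in flight, or has failed with a fallback, the key is empty
// and the flag is false so the graph can take the non-compiled path instead.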
void XlaCompileOp::Compute(OpKernelContext* ctx) {
  VLOG(3) << "XlaCompileOp " << def().name()
          << (must_compile_ ? "(must-compile)" : "");
  xla::LocalClient* client;
  const XlaCompiler::CompilationResult* kernel;
  xla::LocalExecutable* executable;
  ResourceVarsSnapshot variables_snapshot;

  std::vector<const Tensor*> inputs = InputsFromContext(ctx);
  bool cannot_compile_cluster;
  {
    mutex_lock guard(cannot_compile_cluster_mu_);
    cannot_compile_cluster = cannot_compile_cluster_;
  }
  XlaCompilationCache::CompileMode compile_mode = [&] {
    if (must_compile_) {
      return XlaCompilationCache::CompileMode::kStrict;
    }
    return GetXlaOpsCommonFlags().tf_xla_async_compilation
               ? XlaCompilationCache::CompileMode::kAsync
               : XlaCompilationCache::CompileMode::kLazy;
  }();

  if (GetXlaOpsCommonFlags().tf_xla_always_defer_compilation ||
      cannot_compile_cluster) {
    executable = nullptr;
  } else {
    auto args_and_variables_snapshot = GetXlaCompilerArgsAndSnapshotVariables(
        resources_, constants_, inputs, ctx);
    OP_REQUIRES_OK(ctx, args_and_variables_snapshot.status());
    const std::vector<XlaCompiler::Argument>& args =
        args_and_variables_snapshot->first;
    variables_snapshot = std::move(args_and_variables_snapshot->second);

    // Do not alias resource updates as locking variables in XlaCompile and
    // unlocking them in XlaRun may lead to deadlocks.
    const Status status = CompileToLocalExecutable(
        ctx, function_, has_ref_vars_, platform_info_, args, compile_mode,
        /*may_alias_resource_update=*/false, &client, &kernel, &executable);
    if (compile_mode != XlaCompilationCache::CompileMode::kLazy ||
        status.code() != error::UNIMPLEMENTED) {
      OP_REQUIRES_OK(ctx, status);
    }

    if (status.code() == error::UNIMPLEMENTED) {
      LOG(WARNING) << "Compilation failed: " << status.ToString()
                   << ".  Falling back to TF function call.";

      BroadcastOptimizationRemark(
          XlaOptimizationRemark::UNIMPLEMENTED_OPERATION, status.ToString())
          .IgnoreError();
      executable = nullptr;
      mutex_lock guard(cannot_compile_cluster_mu_);
      cannot_compile_cluster_ = true;
    }
  }

  AllocatorAttributes host_alloc_attrs;
  host_alloc_attrs.set_gpu_compatible(true);
  host_alloc_attrs.set_on_host(true);
  Allocator* cpu_allocator = ctx->device()->GetAllocator(host_alloc_attrs);

  // Async compilation returns nullptr executable without an error.
  if (!executable) {
    DCHECK(!must_compile_);
    Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));

    Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({}));
    compilation_successful.scalar<bool>()() = false;
    ctx->set_output(0, Tensor(cpu_allocator, DT_STRING, TensorShape({})));
    ctx->set_output(1, compilation_successful);
    return;
  }

  // Each execution of an XlaCompile op creates a new XlaExecutableClosure, even
  // if it didn't have to compile the cluster because of a compilation-cache
  // hit.  This is because we at least need new snapshots of the resource
  // variables.
  XlaExecutableClosureStore::KeyT key =
      XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure(
          client, executable, kernel, std::move(variables_snapshot),
          constants_.size()));

  Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));
  compilation_key.flat<tstring>()(0) = key;

  Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({}));
  compilation_successful.flat<bool>()(0) = true;

  ctx->set_output(0, compilation_key);
  ctx->set_output(1, compilation_successful);
}

XlaRunOp::XlaRunOp(OpKernelConstruction* ctx)
    : OpKernel(ctx), platform_info_(XlaPlatformInfoFromDevice(ctx->device())) {}

void XlaRunOp::Compute(OpKernelContext* ctx) {
  VLOG(3) << "XlaRunOp " << def().name();
  Tensor key_tensor = ctx->input(ctx->num_inputs() - 1);
  const XlaExecutableClosureStore::KeyT& key = key_tensor.flat<tstring>()(0);

  XlaExecutableClosure closure =
      XlaExecutableClosureStore::Global()->Consume(key);
  std::shared_ptr<se::DeviceMemoryAllocator> allocator =
      GetAllocator(ctx->device(), GetStream(ctx), platform_info_);
  XlaComputationLaunchContext launch_context =
      GetLaunchContext(platform_info_, ctx, closure.client(), allocator.get());

  // We're missing the must-be-constant inputs; tell `PopulateInputs`
  // about this.  We don't actually need these inputs because they've
  // already been baked into the compiled kernel.
  const xla::HloInputOutputAliasConfig& input_output_alias =
      closure.executable()->executable()->module().input_output_alias_config();
  StatusOr<std::vector<xla::ExecutionInput>> execution_inputs;
  std::map<int, const Tensor*> snapshot_ptrs;
  {
    tensorflow::profiler::TraceMe hlo_module_activity(
        [&] {
          return absl::StrCat(
              "Populate Inputs (",
              closure.compilation_result()->xla_input_shapes.size(), ")");
        },
        tensorflow::profiler::TraceMeLevel::kInfo);

    for (const auto& [variable_index, variable_tensor] :
         closure.resource_var_snapshots()) {
      snapshot_ptrs.emplace(variable_index, variable_tensor.has_value()
                                                ? &variable_tensor.value()
                                                : nullptr);
    }
    execution_inputs = launch_context.PopulateInputs(
        ctx, closure.compilation_result(), snapshot_ptrs,
        /*missing_ctx_input_prefix=*/closure.num_constant_args(),
        input_output_alias);
    OP_REQUIRES_OK(ctx, execution_inputs.status());
  }

  xla::ExecutableRunOptions run_options;
  StatusOr<xla::ExecutionOutput> execution_output = RunExecutable(
      platform_info_, launch_context, std::move(*execution_inputs), run_options,
      closure.executable(), ctx, allocator.get());
  OP_REQUIRES(ctx, execution_output.ok(), execution_output.status());

  tensorflow::profiler::TraceMe hlo_module_activity(
      [&] {
        return absl::StrCat("Populate Outputs (", ctx->num_outputs(), ")");
      },
      tensorflow::profiler::TraceMeLevel::kInfo);

  StatusOr<std::vector<VariableInfo>> variable_infos = GatherVariableInfo(
      ctx, *closure.compilation_result(), closure.num_constant_args());
  OP_REQUIRES_OK(ctx, variable_infos.status());
  OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(*variable_infos)));
  OP_REQUIRES_OK(
      ctx,
      launch_context.PopulateOutputs(
          ctx, closure.compilation_result(), execution_output->ConsumeResult(),
          /*missing_ctx_input_prefix=*/closure.num_constant_args(),
          absl::MakeSpan(*variable_infos), input_output_alias, snapshot_ptrs));
}

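// XlaMergeOp forwards whichever of its two inputs is present to its single
// output. It merges the two mutually exclusive branches built around
// _XlaCompile: the _XlaRun result, and the fallback TF function call taken
// when no compiled executable is available (see XlaCompileOp::Compute above).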
XlaMergeOp::XlaMergeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

void XlaMergeOp::Compute(OpKernelContext* ctx) {
  VLOG(3) << "XlaMergeOp " << def().name();
  int i = 0;
  if (ctx->has_input(i) || ctx->has_input(++i)) {
    ctx->set_output(0, ctx->input(i));
  }
}

REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);

REGISTER_KERNEL_BUILDER(Name("XlaLaunch")
                            .Device(DEVICE_GPU)
                            .HostMemory("constants")
                            .HostMemory("resources"),
                        XlaLocalLaunchOp);

REGISTER_KERNEL_BUILDER(Name("_XlaCompile").Device(DEVICE_CPU), XlaCompileOp);
REGISTER_KERNEL_BUILDER(Name("_XlaCompile")
                            .Device(DEVICE_GPU)
                            .HostMemory("constants")
                            .HostMemory("key")
                            .HostMemory("compilation_successful")
                            .HostMemory("resources"),
                        XlaCompileOp);

REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_CPU), XlaRunOp);
REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_GPU).HostMemory("key"),
                        XlaRunOp);

REGISTER_KERNEL_BUILDER(Name("_XlaMerge").Device(DEVICE_CPU), XlaMergeOp);
REGISTER_KERNEL_BUILDER(Name("_XlaMerge").Device(DEVICE_GPU), XlaMergeOp);

}  // namespace tensorflow