xref: /aosp_15_r20/external/tensorflow/tensorflow/core/common_runtime/function.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/function.h"
17 
18 #include <deque>
19 #include <vector>
20 
21 #include "absl/algorithm/container.h"
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_cat.h"
24 #include "tensorflow/core/common_runtime/device.h"
25 #include "tensorflow/core/common_runtime/executor.h"
26 #include "tensorflow/core/common_runtime/executor_factory.h"
27 #include "tensorflow/core/common_runtime/gradients.h"
28 #include "tensorflow/core/common_runtime/graph_constructor.h"
29 #include "tensorflow/core/common_runtime/graph_optimizer.h"
30 #include "tensorflow/core/common_runtime/inline_function_utils.h"
31 #include "tensorflow/core/common_runtime/memory_types.h"
32 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
33 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
34 #include "tensorflow/core/common_runtime/single_threaded_executor.h"
35 #include "tensorflow/core/framework/collective.h"
36 #include "tensorflow/core/framework/function.h"
37 #include "tensorflow/core/framework/function_handle_cache.h"
38 #include "tensorflow/core/framework/metrics.h"
39 #include "tensorflow/core/framework/node_def.pb.h"
40 #include "tensorflow/core/framework/node_def_util.h"
41 #include "tensorflow/core/framework/op.h"
42 #include "tensorflow/core/framework/op_kernel.h"
43 #include "tensorflow/core/framework/versions.pb.h"
44 #include "tensorflow/core/graph/algorithm.h"
45 #include "tensorflow/core/graph/control_flow.h"
46 #include "tensorflow/core/graph/node_builder.h"
47 #include "tensorflow/core/graph/optimizer_cse.h"
48 #include "tensorflow/core/lib/core/threadpool.h"
49 #include "tensorflow/core/lib/gtl/map_util.h"
50 #include "tensorflow/core/platform/macros.h"
51 #include "tensorflow/core/platform/str_util.h"
52 #include "tensorflow/core/profiler/lib/connected_traceme.h"
53 #include "tensorflow/core/profiler/lib/traceme.h"
54 #include "tensorflow/core/protobuf/config.pb.h"
55 
56 // See core/kernels/function_ops.cc for related kernels.
57 
58 namespace tensorflow {
59 
// A few string constants used throughout this module.
//
// These alias the canonical op names declared on FunctionLibraryDefinition so
// the rest of this file can refer to them tersely.
static constexpr const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
static constexpr const char* const kDeviceArgOp =
    FunctionLibraryDefinition::kDeviceArgOp;
static constexpr const char* const kRetOp = FunctionLibraryDefinition::kRetOp;
static constexpr const char* const kDeviceRetOp =
    FunctionLibraryDefinition::kDeviceRetOp;
static constexpr const char* const kGradientOp =
    FunctionLibraryDefinition::kGradientOp;
// Name prefix used for nodes created by the AddArg/AddRet helpers below.
static constexpr const char* const kNodeLabel = "Func";
static constexpr const char* const kFuncAttr =
    FunctionLibraryDefinition::kFuncAttr;
72 
73 // Represents the index-th output of a node.
74 struct Endpoint {
75   Node* node;
76   int index;
77 
78   // Returns the string name represents this endpoint.
nametensorflow::Endpoint79   string name() const {
80     if (index == 0) {
81       return node->name();
82     } else {
83       return strings::StrCat(node->name(), ":", index);
84     }
85   }
86 
dtypetensorflow::Endpoint87   DataType dtype() const { return node->output_type(index); }
88 };
89 
90 struct EndpointHash {
operator ()tensorflow::EndpointHash91   uint64 operator()(const Endpoint& x) const {
92     return Hash64(reinterpret_cast<const char*>(&x.node), sizeof(Node*),
93                   x.index);
94   }
95 };
96 
97 struct EndpointEq {
operator ()tensorflow::EndpointEq98   bool operator()(const Endpoint& x, const Endpoint& y) const {
99     return (x.node == y.node) && (x.index == y.index);
100   }
101 };
102 
103 // The following Add* routines are used to add a few graph nodes while
104 // functions are transformed.
AddArg(Graph * g,DataType dtype,int index)105 static Node* AddArg(Graph* g, DataType dtype, int index) {
106   DCHECK_LT(0, dtype);
107   DCHECK_LT(dtype, DT_FLOAT_REF);
108   NodeDef ndef;
109   ndef.set_name(g->NewName(kNodeLabel));
110   ndef.set_op(kArgOp);
111   AddNodeAttr("T", dtype, &ndef);
112   AddNodeAttr("index", index, &ndef);
113   Status s;
114   Node* ret = g->AddNode(ndef, &s);
115   TF_CHECK_OK(s);
116   return ret;
117 }
118 
AddRet(Graph * g,Endpoint input,int index)119 static Node* AddRet(Graph* g, Endpoint input, int index) {
120   DCHECK_LT(0, input.dtype());
121   DCHECK_LT(input.dtype(), DT_FLOAT_REF);
122   NodeDef ndef;
123   ndef.set_name(g->NewName(kNodeLabel));
124   ndef.set_op(kRetOp);
125   ndef.add_input(input.name());
126   AddNodeAttr("T", input.dtype(), &ndef);
127   AddNodeAttr("index", index, &ndef);
128   Status s;
129   Node* ret = g->AddNode(ndef, &s);
130   TF_CHECK_OK(s);
131   g->AddEdge(input.node, input.index, ret, 0);
132   return ret;
133 }
134 
// FunctionLibraryRuntime implementation that forwards all the function calls to
// the base runtime implementation, and only overrides FunctionLibraryDefinition
// in calls to Instantiate (if caller doesn't provide the
// InstantiateOptions::lib_def option).
//
// When the function library runtime (FunctionLibraryRuntimeImpl specifically)
// instantiates a function into a Graph object, it also creates an Executor for
// it. That executor has a pointer to the function library runtime instance,
// that is used to instantiate all nested function calls.
//
// The function library definition used to instantiate the function must be
// preserved in the Executor's function library runtime.
//
// IMPORTANT: This runtime is intended for use only in executors created for
// functions instantiated into a graph in FunctionLibraryRuntimeImpl.
class FunctionLibraryRuntimeOverlay : public FunctionLibraryRuntime {
 public:
  // Neither `base_flr` nor `lib_def` is owned; both must outlive this overlay.
  FunctionLibraryRuntimeOverlay(FunctionLibraryRuntime* base_flr,
                                const FunctionLibraryDefinition* lib_def)
      : base_flr_(base_flr), lib_def_(lib_def) {}
  ~FunctionLibraryRuntimeOverlay() override;

  // Forwards to the base runtime, injecting `lib_def_` as the
  // InstantiateOptions::lib_def override when the caller did not set one.
  Status Instantiate(const string& function_name, AttrSlice attrs,
                     const InstantiateOptions& options,
                     Handle* handle) override;

  Status ReleaseHandle(Handle handle) override;

  const FunctionBody* GetFunctionBody(Handle h) override;

  Status GetRetTypes(Handle h, DataTypeVector* ret_types) override;

  // All Run/RunSync overloads forward unchanged to the base runtime.
  void Run(const Options& opts, Handle handle, gtl::ArraySlice<Tensor> args,
           std::vector<Tensor>* rets, DoneCallback done) override;

  void Run(const Options& opts, Handle handle, CallFrameInterface* call_frame,
           DoneCallback done) override;

  Status RunSync(Options opts, Handle handle, gtl::ArraySlice<Tensor> args,
                 std::vector<Tensor>* rets) override;

  Status RunSync(Options opts, Handle handle,
                 CallFrameInterface* frame) override;

  // Always returns an Internal error: kernel creation through an overlay is
  // deliberately disabled (see the definition for the rationale).
  Status CreateKernel(const std::shared_ptr<const NodeProperties>& props,
                      OpKernel** kernel) override;

  // Looks up only in `lib_def_`; deliberately does not consult the base FLR.
  bool IsStateful(const string& function_name) const override;

  // Returns `lib_def_` when set, otherwise the base runtime's definition.
  const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
      const override;

  // Simple accessors; all forward to the base runtime.
  Env* env() override;
  const ConfigProto* const config_proto() override;
  Device* device() override;
  const Device* device() const override;
  std::function<void(std::function<void()>)>* runner() override;
  const DeviceMgr* device_mgr() const override;

  string DebugString(Handle handle) override;
  int graph_def_version() const override;

  Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
               std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
               FunctionLibraryRuntime** out_flr,
               bool skip_flib_def = false) override;

 private:
  FunctionLibraryRuntime* base_flr_;          // not owned
  const FunctionLibraryDefinition* lib_def_;  // not owned
};
206 
207 FunctionLibraryRuntimeOverlay::~FunctionLibraryRuntimeOverlay() = default;
208 
Instantiate(const string & function_name,AttrSlice attrs,const InstantiateOptions & options,Handle * handle)209 Status FunctionLibraryRuntimeOverlay::Instantiate(
210     const string& function_name, AttrSlice attrs,
211     const InstantiateOptions& options, Handle* handle) {
212   // We automatically set the `lib_def` option for all instantiations, if the
213   // caller doesn't set this option explicitly.
214   if (!options.lib_def && lib_def_) {
215     InstantiateOptions options_copy = options;
216     options_copy.lib_def = lib_def_;
217     return base_flr_->Instantiate(function_name, attrs, options_copy, handle);
218   } else {
219     return base_flr_->Instantiate(function_name, attrs, options, handle);
220   }
221 }
222 
// Handle release is delegated to the base runtime, which owns the
// instantiated items.
Status FunctionLibraryRuntimeOverlay::ReleaseHandle(Handle handle) {
  return base_flr_->ReleaseHandle(handle);
}

// Forwards to the base runtime; the overlay does not cache function bodies.
const FunctionBody* FunctionLibraryRuntimeOverlay::GetFunctionBody(Handle h) {
  return base_flr_->GetFunctionBody(h);
}

// Forwards the return-type query to the base runtime.
Status FunctionLibraryRuntimeOverlay::GetRetTypes(Handle h,
                                                  DataTypeVector* ret_types) {
  return base_flr_->GetRetTypes(h, ret_types);
}
235 
// Runs the instantiated function asynchronously via the base runtime.
void FunctionLibraryRuntimeOverlay::Run(const Options& opts, Handle handle,
                                        gtl::ArraySlice<Tensor> args,
                                        std::vector<Tensor>* rets,
                                        DoneCallback done) {
  base_flr_->Run(opts, handle, args, rets, std::move(done));
}

// As above, but arguments/results flow through a caller-provided call frame.
void FunctionLibraryRuntimeOverlay::Run(const Options& opts, Handle handle,
                                        CallFrameInterface* call_frame,
                                        DoneCallback done) {
  base_flr_->Run(opts, handle, call_frame, std::move(done));
}

// Synchronous variant; `opts` is moved through to the base runtime.
Status FunctionLibraryRuntimeOverlay::RunSync(Options opts, Handle handle,
                                              gtl::ArraySlice<Tensor> args,
                                              std::vector<Tensor>* rets) {
  return base_flr_->RunSync(std::move(opts), handle, args, rets);
}

// Synchronous, call-frame variant.
Status FunctionLibraryRuntimeOverlay::RunSync(Options opts, Handle handle,
                                              CallFrameInterface* call_frame) {
  return base_flr_->RunSync(std::move(opts), handle, call_frame);
}
259 
// Kernel creation through an overlay is unsupported by design; always fails.
Status FunctionLibraryRuntimeOverlay::CreateKernel(
    const std::shared_ptr<const NodeProperties>&, OpKernel**) {
  // We don't have access to base_lib_def_ in base function library runtime (aka
  // FunctionLibraryRuntimeImpl), so to make sure we do not create a kernel with
  // the wrong lib_def we just disable creation of new kernels through overlays.
  //
  // When we call Instantiate from the base runtime with the lib_def option,
  // the base runtime implementation is responsible for correctly passing it
  // through to all kernel constructions.
  return errors::Internal(
      "Overlay function library runtime doesn't support kernel creation.");
}
272 
IsStateful(const string & function_name) const273 bool FunctionLibraryRuntimeOverlay::IsStateful(
274     const string& function_name) const {
275   // Important: we do not forward lookup to the base FLR.
276   const OpDef* op_def;
277   const Status s = lib_def_->LookUpOpDef(function_name, &op_def);
278   return s.ok() && op_def->is_stateful();
279 }
280 
// The simple accessors below forward to the base runtime, except
// GetFunctionLibraryDefinition, which prefers the overlay's lib_def_.
Env* FunctionLibraryRuntimeOverlay::env() { return base_flr_->env(); }

const ConfigProto* const FunctionLibraryRuntimeOverlay::config_proto() {
  return base_flr_->config_proto();
}

Device* FunctionLibraryRuntimeOverlay::device() { return base_flr_->device(); }

const Device* FunctionLibraryRuntimeOverlay::device() const {
  return base_flr_->device();
}

std::function<void(std::function<void()>)>*
FunctionLibraryRuntimeOverlay::runner() {
  return base_flr_->runner();
}

const DeviceMgr* FunctionLibraryRuntimeOverlay::device_mgr() const {
  return base_flr_->device_mgr();
}

// Returns the overriding library definition when present; otherwise falls
// back to the base runtime's definition.
const FunctionLibraryDefinition*
FunctionLibraryRuntimeOverlay::GetFunctionLibraryDefinition() const {
  return lib_def_ ? lib_def_ : base_flr_->GetFunctionLibraryDefinition();
}

string FunctionLibraryRuntimeOverlay::DebugString(Handle handle) {
  return base_flr_->DebugString(handle);
}

int FunctionLibraryRuntimeOverlay::graph_def_version() const {
  return base_flr_->graph_def_version();
}
314 
// Clones via the base runtime; the overlay itself is not reproduced.
Status FunctionLibraryRuntimeOverlay::Clone(
    std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
    std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
    FunctionLibraryRuntime** out_flr, bool skip_flib_def) {
  // NOTE(ezhulenev): The cloned FunctionLibraryRuntime will be missing the
  // FunctionLibraryDefinition override, but that's ok because we anyway do not
  // copy / clone instantiated items from the base FLR.
  return base_flr_->Clone(out_lib_def, out_pflr, out_flr, skip_flib_def);
}
324 
// The per-device FunctionLibraryRuntime: instantiates functions from a
// FunctionLibraryDefinition into Graph objects, creates executors for them,
// and runs them. Multi-device handles are routed through `parent_`
// (a ProcessFunctionLibraryRuntime) — see e.g. GetRetTypes.
class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
 public:
  // None of the pointer arguments are owned; all must outlive this runtime.
  FunctionLibraryRuntimeImpl(const DeviceMgr* dmgr, Env* env,
                             const ConfigProto* config, Device* device,
                             int graph_def_version,
                             const FunctionLibraryDefinition* lib_def,
                             thread::ThreadPool* default_thread_pool,
                             const OptimizerOptions& optimizer_options,
                             const SessionMetadata* session_metadata,
                             ProcessFunctionLibraryRuntime* parent);

  ~FunctionLibraryRuntimeImpl() override;

  Status Instantiate(const string& function_name, AttrSlice attrs,
                     const InstantiateOptions& options,
                     Handle* handle) override;

  Status ReleaseHandle(Handle handle) override;

  const FunctionBody* GetFunctionBody(Handle handle) override;

  Status GetRetTypes(Handle handle, DataTypeVector* ret_types) override;

  Status CreateKernel(const std::shared_ptr<const NodeProperties>& props,
                      OpKernel** kernel) override;

  void Run(const Options& opts, Handle handle, gtl::ArraySlice<Tensor> args,
           std::vector<Tensor>* rets, DoneCallback done) override;
  void Run(const Options& opts, Handle handle, CallFrameInterface* frame,
           DoneCallback done) override;
  Status RunSync(Options opts, Handle handle, gtl::ArraySlice<Tensor> args,
                 std::vector<Tensor>* rets) override;
  Status RunSync(Options opts, Handle handle,
                 CallFrameInterface* call_frame) override;

  bool IsStateful(const string& function) const override;

  const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
      const override {
    return base_lib_def_;
  }

  Device* device() override { return device_; }
  const Device* device() const override { return device_; }

  std::function<void(std::function<void()>)>* runner() override {
    return &default_runner_;
  }

  const DeviceMgr* device_mgr() const override { return device_mgr_; }
  Env* env() override { return env_; }
  const ConfigProto* const config_proto() override { return config_; }
  int graph_def_version() const override { return graph_def_version_; }

  string DebugString(Handle h) override;

  Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
               std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
               FunctionLibraryRuntime** out_flr,
               bool skip_flib_def = false) override;

 private:
  typedef FunctionLibraryRuntimeImpl ME;

  // All pointers below are borrowed from the constructor arguments.
  const DeviceMgr* const device_mgr_;
  Device* const device_;
  Env* const env_;
  const ConfigProto* const config_;
  const int graph_def_version_;
  const FunctionLibraryDefinition* const base_lib_def_;
  GraphOptimizer optimizer_;
  const SessionMetadata* const session_metadata_;
  Executor::Args::Runner default_runner_;
  const string device_name_;

  // Closures set up in the constructor: signature lookup against
  // base_lib_def_, and kernel creation bound to this runtime.
  std::function<Status(const string&, const OpDef**)> get_func_sig_;
  std::function<Status(const std::shared_ptr<const NodeProperties>&,
                       OpKernel**)>
      create_kernel_;

  mutable mutex mu_;

  // Next local handle to hand out.
  int next_handle_ TF_GUARDED_BY(mu_);

  // The instantiated and transformed function is encoded as a Graph
  // object, and an executor is created for the graph.
  struct Item {
    uint64 instantiation_counter = 0;
    std::unique_ptr<const Graph> graph = nullptr;
    const FunctionLibraryDefinition* lib_def = nullptr;  // Not owned.
    FunctionBody* func_graph = nullptr;
    Executor* exec = nullptr;
    FunctionLibraryRuntimeOverlay* overlay_flr = nullptr;
    string executor_type;
    bool allow_small_function_optimizations = false;
    bool allow_control_flow_sync_execution = false;

    // Item owns func_graph, exec and overlay_flr.
    ~Item() {
      delete this->func_graph;
      delete this->exec;
      delete this->overlay_flr;
    }
  };
  // Held by unique_ptr so the destructor can release the whole map up front;
  // ReleaseHandle checks for a null items_ (see ~FunctionLibraryRuntimeImpl).
  std::unique_ptr<absl::flat_hash_map<Handle, std::unique_ptr<Item>>> items_
      TF_GUARDED_BY(mu_);
  std::unique_ptr<FunctionHandleCache> function_handle_cache_;
  ProcessFunctionLibraryRuntime* parent_ = nullptr;  // not owned.

  // Overloads the CreateKernel method, providing a FunctionLibraryRuntime
  // to use for kernel creation and execution. In particular, this method can
  // accept a FunctionLibraryRuntimeOverlay that overlays a different
  // FunctionLibraryDefinition.
  Status CreateKernel(const std::shared_ptr<const NodeProperties>& props,
                      FunctionLibraryRuntime* flr, OpKernel** kernel);
  Status FunctionDefToBody(const FunctionDef& fdef, AttrSlice attrs,
                           const FunctionLibraryDefinition* lib_def,
                           std::unique_ptr<FunctionBody>* fbody);
  Status CreateItem(Item** item);
  Status GetOrCreateItem(LocalHandle local_handle, Item** item);
  Status InstantiateSymbolicGradient(const NameAttrList& func,
                                     const FunctionLibraryDefinition* lib_def,
                                     std::unique_ptr<FunctionBody>* g_body);
  bool IsLocalTarget(const InstantiateOptions& options) const;
  AttrValueMap FixAttrs(const AttrSlice& attrs);
  void RunRemote(const Options& opts, Handle handle,
                 gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
                 Item* item, DoneCallback done);

  // TODO(fishx): Avoid using std::unique_ptr for PrivateIntraProcessRendezvous,
  // since it will allocate the object on heap.
  Status PrepareRunSync(
      Handle handle, Options* run_opts, Item** out_item,
      std::unique_ptr<PrivateIntraProcessRendezvous>* out_rendezvous);

  void ExecutorArgsFromOptions(const FunctionLibraryRuntime::Options& run_opts,
                               CallFrameInterface* frame,
                               Executor::Args* exec_args);

  TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl);
};
465 
FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
    const DeviceMgr* dmgr, Env* env, const ConfigProto* config, Device* device,
    int graph_def_version, const FunctionLibraryDefinition* lib_def,
    thread::ThreadPool* default_thread_pool,
    const OptimizerOptions& optimizer_options,
    const SessionMetadata* session_metadata,
    ProcessFunctionLibraryRuntime* parent)
    : device_mgr_(dmgr),
      device_(device),
      env_(env),
      config_(config),
      graph_def_version_(graph_def_version),
      base_lib_def_(lib_def),
      optimizer_(optimizer_options),
      session_metadata_(session_metadata),
      default_runner_(nullptr),
      // A null device means this FLR addresses the process-wide default.
      device_name_(device_ == nullptr
                       ? ProcessFunctionLibraryRuntime::kDefaultFLRDevice
                       : device_->name()),
      next_handle_(0),
      items_(std::make_unique<
             absl::flat_hash_map<Handle, std::unique_ptr<Item>>>()),
      function_handle_cache_(std::make_unique<FunctionHandleCache>(this)),
      parent_(parent) {
  // Default signature lookup and kernel creation both resolve against this
  // runtime's base function library definition.
  get_func_sig_ = [this](const string& op, const OpDef** sig) {
    return base_lib_def_->LookUpOpDef(op, sig);
  };
  create_kernel_ = [this](const std::shared_ptr<const NodeProperties>& props,
                          OpKernel** kernel) {
    return CreateKernel(props, kernel);
  };
  // Choose the runner's thread pool: prefer the device-owned pool, then the
  // caller-provided default pool. If neither exists, default_runner_ stays
  // null.
  thread::ThreadPool* pool = nullptr;
  if (device_ != nullptr) {
    pool = device_->tensorflow_device_thread_pool();
  }
  if (pool == nullptr) {
    pool = default_thread_pool;
  }
  if (pool != nullptr) {
    default_runner_ = [pool](Executor::Args::Closure c) {
      pool->Schedule(std::move(c));
    };
  }
}
510 
FunctionLibraryRuntimeImpl::~FunctionLibraryRuntimeImpl() {
  // Deleting the items_ list will delete all the function handles registered in
  // this object. A function may contain a few sub-functions which have also
  // been registered in this object. Deleting the parent function will call
  // ReleaseHandle in this class again for each of the sub-functions. These
  // circular calls may cause segfault since the items_ may have already been
  // partially deleted when releasing handles of sub-functions. Explicitly
  // release items_ here and check it in ReleaseHandle to avoid this.
  items_.reset();
}
521 
// An asynchronous op kernel which executes an instantiated function
// defined in a library.
class CallOp : public AsyncOpKernel {
 public:
  // `handle` identifies the already-instantiated function to run.
  CallOp(FunctionLibraryRuntime::Handle handle, OpKernelConstruction* ctx)
      : AsyncOpKernel(ctx), handle_(handle) {}

  ~CallOp() override {
    // TODO(iga): Release the cached handle_
  }

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    FunctionLibraryRuntime* lib = ctx->function_library();
    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                      errors::Internal("No function library is provided."),
                      done);
    // Thread the kernel context's execution state through to the function
    // call (rendezvous, cancellation, step container, etc.).
    FunctionLibraryRuntime::Options opts;
    opts.rendezvous = ctx->rendezvous();
    opts.cancellation_manager = ctx->cancellation_manager();
    opts.step_container = ctx->step_container();
    opts.stats_collector = ctx->stats_collector();
    opts.runner = ctx->runner();
    opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
    opts.collective_executor = ctx->collective_executor();
    opts.stack_trace = ctx->stack_trace();
    std::vector<Tensor> args;
    args.reserve(ctx->num_inputs());
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      args.push_back(ctx->input(i));
    }
    // Heap-allocated because it must outlive this stack frame; the completion
    // callback below owns it and deletes it on every path.
    std::vector<Tensor>* rets = new std::vector<Tensor>;
    lib->Run(opts, handle_, args, rets,
             [ctx, done, rets](const Status& status) {
               if (!status.ok()) {
                 ctx->SetStatus(status);
               } else {
                 const int ret_size = static_cast<int>(rets->size());
                 CHECK_EQ(ret_size, ctx->num_outputs());
                 for (int i = 0; i < ret_size; ++i) {
                   ctx->set_output(i, (*rets)[i]);
                 }
               }
               delete rets;
               done();
             });
  }

 private:
  // Handle of the instantiated function; never released here (see TODO in
  // the destructor).
  FunctionLibraryRuntime::Handle handle_;

  TF_DISALLOW_COPY_AND_ASSIGN(CallOp);
};
574 
GetFunctionBody(Handle h)575 const FunctionBody* FunctionLibraryRuntimeImpl::GetFunctionBody(Handle h) {
576   LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, h);
577   if (local_handle == kInvalidLocalHandle) {
578     LOG(ERROR) << "Could not find Handle: " << h
579                << " on device: " << device_name_;
580     return nullptr;
581   }
582 
583   tf_shared_lock l(mu_);
584   auto iter = items_->find(local_handle);
585   CHECK(iter != items_->end());
586   return iter->second->func_graph;
587 }
588 
GetRetTypes(Handle h,DataTypeVector * ret_types)589 Status FunctionLibraryRuntimeImpl::GetRetTypes(Handle h,
590                                                DataTypeVector* ret_types) {
591   if (parent_->IsMultiDevice(h)) {
592     return parent_->GetRetTypes(h, ret_types);
593   }
594   LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, h);
595   if (local_handle == kInvalidLocalHandle) {
596     return errors::InvalidArgument("Handle ", h, " not found.");
597   }
598   const FunctionBody* fbody = GetFunctionBody(h);
599   *ret_types = fbody->ret_types;
600   return OkStatus();
601 }
602 
// Public entry point: creates the kernel using this runtime itself as the
// FunctionLibraryRuntime for nested instantiation.
Status FunctionLibraryRuntimeImpl::CreateKernel(
    const std::shared_ptr<const NodeProperties>& props, OpKernel** kernel) {
  return CreateKernel(props, this, kernel);
}
607 
// Creates an OpKernel for `props`, using `flr` for any nested lookups.
// Resolution order:
//   1. a registered CustomKernelCreator, when it can create the kernel;
//   2. the registered kernel, when the op is primitive (not in lib_def);
//   3. otherwise, instantiate the function and wrap it in a CallOp.
Status FunctionLibraryRuntimeImpl::CreateKernel(
    const std::shared_ptr<const NodeProperties>& props,
    FunctionLibraryRuntime* flr, OpKernel** kernel) {
  // If a custom kernel creator is given, try that.
  Status s;
  const CustomKernelCreator* custom_kernel_creator =
      GetDefaultCustomKernelCreator();
  if (custom_kernel_creator &&
      custom_kernel_creator->CanCreateKernel(*flr, props)) {
    std::unique_ptr<OpKernel> ret;
    s = custom_kernel_creator->CreateKernel(flr, props, &ret);
    if (s.ok()) {
      *kernel = ret.release();
    } else {
      VLOG(2) << "Custom creator error: " << s;
    }
    return s;
  }

  const FunctionLibraryDefinition* lib_def =
      flr->GetFunctionLibraryDefinition();
  if (lib_def->Find(props->node_def.op()) == nullptr) {
    // A primitive operation. Creates the registered kernel.
    return CreateNonCachedKernel(device_, flr, props, graph_def_version_,
                                 kernel);
  }

  // Try to instantiate this function for the func/attr. Maybe it's
  // cached already.
  InstantiateOptions options;
  // Only override lib_def when it differs from this runtime's base library.
  if (lib_def != base_lib_def_) {
    options.lib_def = lib_def;
  }
  Handle handle;
  TF_RETURN_IF_ERROR(Instantiate(props->node_def.op(),
                                 AttrSlice(&props->node_def.attr()), options,
                                 &handle));

  const FunctionBody* fbody = GetFunctionBody(handle);
  CHECK_NOTNULL(fbody);

  // TODO(zhifengc): For now, we assume int32 and resources are always on host
  // memory and other types are always on device memory. We should do type
  // inference over function body to derive the correct input/output memory
  // types.
  MemoryTypeVector input_memory_types;
  for (const auto& t : fbody->arg_types) {
    input_memory_types.push_back(MTypeFromDType(t));
  }
  MemoryTypeVector output_memory_types;
  for (const auto& t : fbody->ret_types) {
    output_memory_types.push_back(MTypeFromDType(t));
  }

  // Constructs a CallOp kernel for running the instantiated function.
  auto device_type = DeviceType(device_->attributes().device_type());
  // NOTE(review): `new_props` is constructed below but `construction` is
  // handed the original `props` — confirm whether `new_props` was intended
  // to be passed instead.
  auto new_props = std::make_shared<NodeProperties>(
      &fbody->fdef.signature(), props->node_def, fbody->arg_types,
      fbody->ret_types);
  OpKernelConstruction construction(
      device_type, device_, device_->GetAllocator(AllocatorAttributes()), flr,
      device_->resource_manager(), props, input_memory_types,
      output_memory_types, graph_def_version_, &s);
  if (s.ok()) {
    *kernel = new CallOp(handle, &construction);
  }
  return s;
}
676 
FunctionDefToBody(const FunctionDef & fdef,AttrSlice attrs,const FunctionLibraryDefinition * lib_def,std::unique_ptr<FunctionBody> * fbody)677 Status FunctionLibraryRuntimeImpl::FunctionDefToBody(
678     const FunctionDef& fdef, AttrSlice attrs,
679     const FunctionLibraryDefinition* lib_def,
680     std::unique_ptr<FunctionBody>* fbody) {
681   if (lib_def == base_lib_def_) {
682     return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig_, fbody);
683   } else {
684     auto get_func_sig = [lib_def](const string& op, const OpDef** sig) {
685       return lib_def->LookUpOpDef(op, sig);
686     };
687     return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig, fbody);
688   }
689 }
690 
// Builds a FunctionBody computing the symbolic gradient of `func`.
//
// If `func` names a primitive op (not found in `lib_def`), the registered
// gradient creator produces a gradient FunctionDef which is converted to a
// body. Otherwise the function is instantiated and SymbolicGradient is
// applied to its body.
Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient(
    const NameAttrList& func, const FunctionLibraryDefinition* lib_def,
    std::unique_ptr<FunctionBody>* g_body) {
  const FunctionDef* fdef = lib_def->Find(func.name());
  if (fdef == nullptr) {
    // f is a primitive op.
    gradient::Creator creator;
    TF_RETURN_IF_ERROR(gradient::GetOpGradientCreator(func.name(), &creator));
    if (creator == nullptr) {
      return errors::InvalidArgument("No gradient is defined for ",
                                     func.name());
    }
    FunctionDef grad_fdef;
    // TODO(josh11b): Should filter out the attrs from func that aren't used
    // by the gradient function.
    TF_RETURN_IF_ERROR(creator(AttrSlice(&func.attr()), &grad_fdef));
    TF_RETURN_IF_ERROR(
        FunctionDefToBody(grad_fdef, AttrSlice(&func.attr()), lib_def, g_body));
  } else {
    // f is a user-defined function.
    InstantiateOptions options;
    // Only override lib_def when it differs from the base library.
    if (lib_def != base_lib_def_) {
      options.lib_def = lib_def;
    }
    Handle f_handle;
    TF_RETURN_IF_ERROR(
        Instantiate(func.name(), AttrSlice(&func.attr()), options, &f_handle));
    const FunctionBody* f_body = GetFunctionBody(f_handle);
    CHECK_NOTNULL(f_body);
    *g_body = SymbolicGradient(*f_body);
  }
  return OkStatus();
}
724 
IsLocalTarget(const InstantiateOptions & options) const725 bool FunctionLibraryRuntimeImpl::IsLocalTarget(
726     const InstantiateOptions& options) const {
727   if (device_ == nullptr) return true;
728   if (options.target.empty()) return true;
729   if (options.is_multi_device_function) return false;
730   Device* target_device;
731   if (!device_mgr_->LookupDevice(options.target, &target_device).ok()) {
732     VLOG(1) << "Not instantiating function in FLR because failed to "
733             << "find device " << options.target << " in device manager";
734     return false;
735   }
736   if (target_device != device_) {
737     VLOG(1) << "Not instantiating function in FLR because target device "
738             << options.target
739             << " is different from FLR's device: " << device_->DebugString();
740     return false;
741   }
742   return true;
743 }
744 
// Instantiates `function_name` with `attrs` on this runtime's device and
// returns a handle in `*handle`. Non-local targets are forwarded to the
// parent ProcessFunctionLibraryRuntime. Instantiation is reference-counted:
// re-instantiating the same canonical key bumps a counter instead of
// rebuilding the function body.
Status FunctionLibraryRuntimeImpl::Instantiate(
    const string& function_name, AttrSlice attrs,
    const InstantiateOptions& options, Handle* handle) {
  if (!IsLocalTarget(options)) {
    return parent_->Instantiate(function_name, attrs, options, handle);
  }

  if (options.use_function_cache) {
    // Delegate to the per-FLR handle cache, clearing the flag so the cache's
    // callback into Instantiate() does not recurse into this branch again.
    InstantiateOptions options_copy(options);
    options_copy.use_function_cache = false;
    return function_handle_cache_->Instantiate(function_name, attrs,
                                               options_copy, handle);
  }

  // Since this is a local target, ensure that the local `device_name_` appears
  // in the canonical key.
  InstantiateOptions options_copy(options);
  options_copy.target = device_name_;
  const string key = Canonicalize(function_name, attrs, options_copy);

  {
    // Fast path: the function was already instantiated under this key.
    mutex_lock l(mu_);
    *handle = parent_->GetHandle(key);
    if (*handle != kInvalidHandle) {
      FunctionLibraryRuntime::LocalHandle handle_on_device =
          parent_->GetHandleOnDevice(device_name_, *handle);
      if (handle_on_device == kInvalidLocalHandle) {
        return errors::Internal("LocalHandle not found for handle ", *handle,
                                ".");
      }
      auto item_handle = items_->find(handle_on_device);
      if (item_handle == items_->end()) {
        return errors::Internal("LocalHandle ", handle_on_device,
                                " for handle ", *handle,
                                " not found in items.");
      }
      ++item_handle->second->instantiation_counter;
      return OkStatus();
    }
  }

  const FunctionLibraryDefinition* lib_def =
      options.lib_def ? options.lib_def : base_lib_def_;
  std::unique_ptr<FunctionBody> fbody;
  if (function_name == kGradientOp) {
    // SymbolicGradient: instantiate the gradient of the function named in
    // attr "f", preferring a registered custom gradient when one exists.
    const AttrValue* f = attrs.Find(kFuncAttr);
    if (f == nullptr) {
      return errors::InvalidArgument("SymbolicGradient is missing attr: f");
    }
    const auto& func = f->func();
    if (func.name() == kGradientOp) {
      return errors::InvalidArgument("Can't take gradient of SymbolicGradient");
    }
    const string grad = lib_def->FindGradient(func.name());
    if (!grad.empty()) {
      return Instantiate(grad, AttrSlice(&func.attr()), options, handle);
    }
    TF_RETURN_IF_ERROR(InstantiateSymbolicGradient(func, lib_def, &fbody));
  } else {
    const FunctionDef* fdef = lib_def->Find(function_name);
    if (fdef == nullptr) {
      return errors::NotFound("Function ", function_name, " is not defined.");
    }
    TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, attrs, lib_def, &fbody));
  }

  LocalHandle local_handle;
  {
    mutex_lock l(mu_);
    // Re-check under the lock: another thread may have instantiated the same
    // key while we were building the function body above. If so, just bump
    // the counter and discard our freshly built body.
    *handle = parent_->GetHandle(key);
    if (*handle != kInvalidHandle) {
      local_handle = parent_->GetHandleOnDevice(device_name_, *handle);
      ++(*items_)[local_handle]->instantiation_counter;
    } else {
      *handle = parent_->AddHandle(key, device_name_, next_handle_);
      Item* item = new Item;
      item->func_graph = fbody.release();
      item->instantiation_counter = 1;
      item->executor_type = ExecutorType(options, attrs);
      item->allow_small_function_optimizations =
          options.allow_small_function_optimizations;
      item->allow_control_flow_sync_execution =
          options.allow_control_flow_sync_execution;
      if (options.lib_def) {
        // Overlay library: kernels created for this item must resolve ops
        // against `options.lib_def` rather than the base library.
        item->overlay_flr =
            new FunctionLibraryRuntimeOverlay(this, options.lib_def);
      }
      local_handle = next_handle_++;
      items_->emplace(local_handle, std::unique_ptr<Item>(item));
    }
  }

  if (options.create_kernels_eagerly) {
    Item* item;
    TF_RETURN_IF_ERROR(GetOrCreateItem(local_handle, &item));
  }

  return OkStatus();
}
844 
// Decrements the instantiation count for `handle`. When the count reaches
// zero the Item is removed from `items_` and destroyed *outside* the lock
// (see comment below), and the handle is removed from the parent runtime.
Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) {
  LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle);
  if (h == kInvalidLocalHandle) {
    // Not instantiated on this device; let the parent release it.
    return parent_->ReleaseHandle(handle);
  }
  std::unique_ptr<Item> item_to_delete;
  Status parent_status;
  {
    mutex_lock l(mu_);
    // Return directly if all items has already been released.
    if (items_ == nullptr) return OkStatus();

    auto it = items_->find(h);
    if (it == items_->end()) {
      return errors::Internal(
          "Inconsistent FunctionLibraryRuntime. Expected to find an item for "
          "handle ",
          h, " but found none");
    }
    std::unique_ptr<Item>& item = it->second;
    --item->instantiation_counter;
    if (item->instantiation_counter == 0) {
      // We don't simply erase h's item because that would trigger
      // item destruction while holding mu_. Item destruction can
      // trigger graph destruction. If the graph contains kernels like
      // CallOp or PartitionCallOp, their destructors will release cached
      // function handles, resulting in deadlock here.
      item_to_delete = std::move(item);
      items_->erase(h);
      parent_status = parent_->RemoveHandle(handle);
    }
  }
  // `item_to_delete` (if set) is destroyed here, after mu_ is released.
  return parent_status;
}
879 
880 namespace {
881 
882 // Removes all stateless nodes that do not contribute to a return
883 // value from the function body. Unlike `RemoveDeadNodes()`, which is
884 // triggered by `OptimizerOptions.do_function_inlining`, this pass
885 // ignores the SINK node, from which (by definition) all nodes are
886 // reverse reachable, and preserves all nodes that are reachable from
887 // control output nodes.
888 //
889 // TODO(ezhulenev, skyewm): Function body should not have special treatment of
890 // stateful ops, graph should encode nodes that must execute with `control_ret`
891 // and `control_output`.
PruneFunctionBody(const FunctionDef & fdef,Graph * g)892 void PruneFunctionBody(const FunctionDef& fdef, Graph* g) {
893   VLOG(2) << "Pruning function body: function_name=" << fdef.signature().name();
894 
895   // `control_ret` nodes must be always executed.
896   std::unordered_set<StringPiece, StringPieceHasher> control_ret_nodes;
897   for (const auto& control_ret : fdef.control_ret()) {
898     control_ret_nodes.insert(control_ret.second);
899   }
900 
901   std::unordered_set<const Node*> nodes;
902   for (auto n : g->nodes()) {
903     // NOTE(mrry): "_Retval" nodes are stateful, and so will be added
904     // to the seed set of `nodes`. "_Arg" nodes are also stateful, but we
905     // specifically exclude them as seeds, to avoid unconditionally executing
906     // unused argument nodes (e.g. in a function like `lambda x, y: y`).
907     // TODO(mrry): Investigate whether the `n->IsControlFlow()` test is
908     // still needed. It would be preferable to prune entire loops and/or
909     // conditionals if they are not used in the graph.
910     if (n->IsControlFlow() ||
911         (n->op_def().is_stateful() && n->type_string() != kArgOp) ||
912         (control_ret_nodes.find(n->name()) != control_ret_nodes.end())) {
913       nodes.insert(n);
914     }
915   }
916   bool changed = PruneForReverseReachability(g, std::move(nodes));
917   if (changed) {
918     FixupSourceAndSinkEdges(g);
919   }
920 }
921 
922 }  // namespace
923 
// Builds the optimized graph and executor for an Item created by
// Instantiate(). Called lazily on first run, or eagerly when the
// instantiation requested `create_kernels_eagerly`.
Status FunctionLibraryRuntimeImpl::CreateItem(Item** item) {
  const FunctionBody* fbody;
  FunctionLibraryRuntime* flr;
  string executor_type;
  {
    // Snapshot the fields we need under a shared lock; the executor itself
    // must be created without holding mu_ (see comment below).
    tf_shared_lock l(mu_);
    fbody = (*item)->func_graph;
    flr = (*item)->overlay_flr
              ? static_cast<FunctionLibraryRuntime*>((*item)->overlay_flr)
              : static_cast<FunctionLibraryRuntime*>(this);
    executor_type = (*item)->executor_type;
  }
  const FunctionLibraryDefinition* lib_def =
      flr->GetFunctionLibraryDefinition();
  auto g = std::make_unique<Graph>(lib_def);
  CopyGraph(*fbody->graph, g.get());

  PruneFunctionBody(fbody->fdef, g.get());
  optimizer_.Optimize(this, env(), device(), &g, GraphOptimizer::Options());
  TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device()->device_type()),
                                       device()->name(), g.get()));

  // Creates an executor based on the g. This must be done without
  // holding mu_ because create_kernel_ calls back into the library.
  LocalExecutorParams params;
  params.device = device_;
  params.function_library = flr;
  params.allow_control_flow_sync_execution =
      (*item)->allow_control_flow_sync_execution;
  if (flr == this) {
    params.create_kernel = create_kernel_;
  } else {
    // Running with an overlay library: kernels must be created against the
    // overlay FLR so op lookups resolve to the overlay's definitions.
    params.create_kernel =
        [this, flr](const std::shared_ptr<const NodeProperties>& props,
                    OpKernel** kernel) {
          return CreateKernel(props, flr, kernel);
        };
  }
  params.delete_kernel = [](OpKernel* kernel) {
    DeleteNonCachedKernel(kernel);
  };
  params.session_metadata = session_metadata_;
  std::unique_ptr<Executor> exec;

  // When the instantiation options request small function optimizations, all
  // graphs which are safe for synchronous execution will set this flag to true:
  if ((*item)->allow_small_function_optimizations && executor_type.empty()) {
    executor_type = "SINGLE_THREADED_EXECUTOR";
  }

  metrics::IncrementTestCounter("flr_executor",
                                (executor_type == "SINGLE_THREADED_EXECUTOR")
                                    ? "single_threaded"
                                    : "default");

  TF_RETURN_IF_ERROR(NewExecutor(executor_type, params, *g, &exec));
  {
    // Guard item since it is already inserted in items_.
    mutex_lock l(mu_);
    // Another thread may have built the executor concurrently; only install
    // ours if the slot is still empty.
    if ((*item)->exec == nullptr) {
      (*item)->graph = std::move(g);
      (*item)->exec = exec.release();
    }
  }
  return OkStatus();
}
990 
GetOrCreateItem(LocalHandle local_handle,Item ** item)991 Status FunctionLibraryRuntimeImpl::GetOrCreateItem(LocalHandle local_handle,
992                                                    Item** item) {
993   {
994     tf_shared_lock l(mu_);
995     auto iter = items_->find(local_handle);
996     if (iter == items_->end()) {
997       return errors::Internal("Local function handle ", local_handle,
998                               " is not valid. Likely an internal error.");
999     }
1000     *item = iter->second.get();
1001     if ((*item)->exec != nullptr) {
1002       return OkStatus();
1003     }
1004   }
1005   // NOTE: We need to call CreateItem out of mu_ because creating an
1006   // executor needs to call CreateKernel.
1007   return CreateItem(item);
1008 }
1009 
ExecutorArgsFromOptions(const FunctionLibraryRuntime::Options & run_opts,CallFrameInterface * frame,Executor::Args * exec_args)1010 void FunctionLibraryRuntimeImpl::ExecutorArgsFromOptions(
1011     const FunctionLibraryRuntime::Options& run_opts, CallFrameInterface* frame,
1012     Executor::Args* exec_args) {
1013   // Inherit the step_id from the caller.
1014   exec_args->step_id = run_opts.step_id;
1015   exec_args->rendezvous = run_opts.rendezvous;
1016   exec_args->stats_collector = run_opts.stats_collector;
1017   exec_args->cancellation_manager = run_opts.cancellation_manager;
1018   exec_args->step_container = run_opts.step_container;
1019   if (run_opts.runner) {
1020     exec_args->runner = *run_opts.runner;
1021   } else {
1022     exec_args->runner = default_runner_;
1023   }
1024   exec_args->collective_executor = run_opts.collective_executor;
1025   exec_args->call_frame = frame;
1026   exec_args->run_all_kernels_inline = run_opts.run_all_kernels_inline;
1027   exec_args->user_intra_op_threadpool = run_opts.user_intra_op_threadpool;
1028   exec_args->coordination_service_agent = run_opts.coordination_service_agent;
1029   exec_args->stack_trace = run_opts.stack_trace;
1030 }
1031 
// Runs an instantiated function whose caller lives on a different device:
// receives the arguments via rendezvous, executes locally, then sends the
// return values back. `frame`, `exec_args`, and `remote_args` are heap
// allocated because they must outlive this call and are deleted inside the
// async callbacks below.
void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
                                           gtl::ArraySlice<Tensor> args,
                                           std::vector<Tensor>* rets,
                                           Item* item, DoneCallback done) {
  string target_device = parent_->GetDeviceName(handle);
  string source_device = opts.source_device;
  RendezvousInterface* rendezvous = opts.rendezvous;
  DeviceContext* device_context;
  Status s = parent_->GetDeviceContext(target_device, &device_context);
  if (!s.ok()) {
    done(s);
    return;
  }
  int64_t src_incarnation, target_incarnation;
  s = parent_->GetDeviceIncarnation(source_device, &src_incarnation);
  s.Update(parent_->GetDeviceIncarnation(target_device, &target_incarnation));
  if (!s.ok()) {
    done(s);
    return;
  }

  const FunctionBody* fbody = GetFunctionBody(handle);
  FunctionCallFrame* frame =
      new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
  Executor::Args* exec_args = new Executor::Args;
  ExecutorArgsFromOptions(opts, frame, exec_args);

  std::vector<AllocatorAttributes> args_alloc_attrs, rets_alloc_attrs;
  args_alloc_attrs.reserve(fbody->arg_types.size());
  rets_alloc_attrs.reserve(fbody->ret_types.size());
  // Note: Functions assume that int32's are always on host memory.
  for (const auto& arg_type : fbody->arg_types) {
    AllocatorAttributes arg_alloc_attrs;
    if (MTypeFromDType(arg_type) == HOST_MEMORY) {
      arg_alloc_attrs.set_on_host(true);
    }
    args_alloc_attrs.push_back(arg_alloc_attrs);
  }
  for (const auto& ret_type : fbody->ret_types) {
    AllocatorAttributes ret_alloc_attrs;
    if (MTypeFromDType(ret_type) == HOST_MEMORY) {
      ret_alloc_attrs.set_on_host(true);
    }
    rets_alloc_attrs.push_back(ret_alloc_attrs);
  }

  bool allow_dead_tensors = opts.allow_dead_tensors;

  // The ProcFLR sends the arguments to the function from the source_device to
  // the target_device. So here we receive those arguments. Similarly, when the
  // computation is done and stored in *rets, we send the return values back
  // to the source_device (caller) so that the ProcFLR can receive them later.
  std::vector<Tensor>* remote_args = new std::vector<Tensor>;
  ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
      source_device, target_device, "arg_", src_incarnation, args.size(),
      device_context, args_alloc_attrs, rendezvous, remote_args,
      [frame, remote_args, item, source_device, target_device,
       target_incarnation, rendezvous, device_context, rets, done, exec_args,
       rets_alloc_attrs, allow_dead_tensors](const Status& status) {
        Status s = status;
        if (s.ok()) {
          s = frame->SetArgs(*remote_args);
        }
        if (!s.ok()) {
          // Receiving or staging the arguments failed: release everything we
          // heap-allocated and report the error.
          delete frame;
          delete remote_args;
          delete exec_args;
          done(s);
          return;
        }
        item->exec->RunAsync(
            *exec_args,
            [frame, rets, done, source_device, target_device,
             target_incarnation, rendezvous, device_context, remote_args,
             rets_alloc_attrs, allow_dead_tensors](const Status& status) {
              Status s = status;
              if (s.ok()) {
                s = frame->ConsumeRetvals(rets, allow_dead_tensors);
              }
              delete frame;
              if (!s.ok()) {
                delete remote_args;
                done(s);
                return;
              }
              // Send the results back to the caller's device.
              s = ProcessFunctionLibraryRuntime::SendTensors(
                  target_device, source_device, "ret_", target_incarnation,
                  *rets, device_context, rets_alloc_attrs, rendezvous);
              delete remote_args;
              done(s);
            });
        // Safe to delete here: RunAsync has already copied what it needs
        // from *exec_args.
        delete exec_args;
      });
}
1126 
// Runs the instantiated function identified by `handle` asynchronously with
// tensor-vector arguments, invoking `done` with the final status. Delegates
// to the parent runtime for non-local handles and to RunRemote() for
// cross-device execution.
void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
                                     gtl::ArraySlice<Tensor> args,
                                     std::vector<Tensor>* rets,
                                     DoneCallback done) {
  if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
    done(errors::Cancelled("Function was cancelled before it was started"));
    return;
  }
  Options run_opts = opts;
  if (opts.create_rendezvous) {
    // Create a per-call rendezvous; it is unref'd (and thus destroyed) by
    // the wrapped done callback once the call finishes.
    auto* rendezvous = new RefCountedIntraProcessRendezvous(device_mgr_);
    run_opts.rendezvous = rendezvous;
    run_opts.create_rendezvous = false;
    done = [done = std::move(done), rendezvous](const Status& status) mutable {
      rendezvous->Unref();
      done(status);
    };
  }

  LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
  if (local_handle == kInvalidLocalHandle) {
    // Not instantiated on this device; let the parent runtime run it.
    parent_->Run(run_opts, handle, args, rets, done);
    return;
  }

  if (run_opts.runner == nullptr) {
    run_opts.runner = &default_runner_;
  }
  DCHECK(run_opts.runner != nullptr);

  Item* item = nullptr;
  Status s = GetOrCreateItem(local_handle, &item);
  if (!s.ok()) {
    done(s);
    return;
  }

  if (run_opts.remote_execution) {
    // NOTE(mrry): `RunRemote()` will set `exec_args->call_frame` for us.
    RunRemote(run_opts, handle, args, rets, item, std::move(done));
    return;
  }

  // Heap-allocated: the frame must outlive this call and is deleted in the
  // executor's done callback below.
  const FunctionBody* fbody = GetFunctionBody(handle);
  FunctionCallFrame* frame =
      new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
  s = frame->SetArgs(args);
  if (!s.ok()) {
    delete frame;
    done(s);
    return;
  }

  profiler::TraceMeProducer activity(
      // To TraceMeConsumers in ExecutorState::Process/Finish.
      [&opts] {
        return profiler::TraceMeEncode("FunctionRun",
                                       {{"id", opts.step_id}, {"_r", 1}});
      },
      profiler::ContextType::kTfExecutor, opts.step_id,
      profiler::TraceMeLevel::kInfo);

  Executor::Args exec_args;
  ExecutorArgsFromOptions(run_opts, frame, &exec_args);

  bool allow_dead_tensors = run_opts.allow_dead_tensors;
  item->exec->RunAsync(
      // Executor args
      exec_args,
      // Done callback.
      [frame, rets, done, allow_dead_tensors](const Status& status) {
        Status s = status;
        if (s.ok()) {
          s = frame->ConsumeRetvals(rets, allow_dead_tensors);
        }
        delete frame;
        done(s);
      });
}
1206 
// Runs the instantiated function identified by `handle` asynchronously using
// a caller-provided call frame. Non-local handles are delegated to the
// parent runtime; remote execution is unsupported through this interface.
void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
                                     CallFrameInterface* frame,
                                     DoneCallback done) {
  if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
    done(errors::Cancelled(""));
    return;
  }

  Options run_opts = opts;
  if (opts.create_rendezvous) {
    // Create a per-call rendezvous; it is unref'd by the wrapped done
    // callback once the call finishes.
    auto* rendezvous = new RefCountedIntraProcessRendezvous(device_mgr_);
    run_opts.rendezvous = rendezvous;
    run_opts.create_rendezvous = false;
    done = [done = std::move(done), rendezvous](const Status& status) mutable {
      rendezvous->Unref();
      done(status);
    };
  }

  LocalHandle local_handle = parent_->GetHandleOnDevice(
      device_name_, handle, /*include_multi_device=*/true);
  if (local_handle == kInvalidLocalHandle) {
    parent_->Run(run_opts, handle, frame, done);
    return;
  }

  if (opts.remote_execution) {
    // NOTE(mrry): This bit is only set for a local function when `parent_`
    // calls back into this class, and the current implementation of
    // `ProcessFunctionLibraryRuntime` currently always uses the vector-based
    // `args`/`rets` interface.
    done(errors::Unimplemented("Remote calling with CallFrameInterface"));
    return;
  }

  Item* item = nullptr;
  Status s = GetOrCreateItem(local_handle, &item);
  if (!s.ok()) {
    done(s);
    return;
  }
  if (run_opts.runner == nullptr) {
    run_opts.runner = &default_runner_;
  }
  DCHECK(run_opts.runner != nullptr);

  profiler::TraceMeProducer activity(
      // To TraceMeConsumers in ExecutorState::Process/Finish.
      [&opts] {
        return profiler::TraceMeEncode("FunctionRun",
                                       {{"id", opts.step_id}, {"_r", 1}});
      },
      profiler::ContextType::kTfExecutor, opts.step_id,
      profiler::TraceMeLevel::kInfo);

  Executor::Args exec_args;
  ExecutorArgsFromOptions(run_opts, frame, &exec_args);
  item->exec->RunAsync(exec_args, std::move(done));
}
1266 
// Shared setup for the two RunSync() overloads: validates the options,
// creates a private rendezvous if requested, and resolves `handle` to a
// local Item. On success, `*out_item == nullptr` signals that the handle is
// not local and the caller should delegate to the parent runtime.
Status FunctionLibraryRuntimeImpl::PrepareRunSync(
    Handle handle, Options* run_opts, Item** out_item,
    std::unique_ptr<PrivateIntraProcessRendezvous>* out_rendezvous) {
  if (run_opts->cancellation_manager &&
      run_opts->cancellation_manager->IsCancelled()) {
    return errors::Cancelled("");
  }

  if (run_opts->remote_execution) {
    // NOTE(mrry): This bit is only set for a local function when `parent_`
    // calls back into this class, and the current implementation of
    // `ProcessFunctionLibraryRuntime` currently always uses the asynchronous
    // Run() method.
    return errors::Unimplemented("Remote calling with RunSync()");
  }

  if (run_opts->create_rendezvous) {
    // The caller owns the rendezvous via `out_rendezvous`, so it stays alive
    // for the duration of the synchronous run.
    *out_rendezvous =
        std::make_unique<PrivateIntraProcessRendezvous>(device_mgr_);
    run_opts->rendezvous = out_rendezvous->get();
    run_opts->create_rendezvous = false;
  }

  LocalHandle local_handle = parent_->GetHandleOnDevice(
      device_name_, handle, /*include_multi_device=*/true);
  if (local_handle == kInvalidLocalHandle) {
    *out_item = nullptr;
    return OkStatus();
  }

  TF_RETURN_IF_ERROR(GetOrCreateItem(local_handle, out_item));

  if (run_opts->runner == nullptr) {
    run_opts->runner = &default_runner_;
  }
  DCHECK(run_opts->runner != nullptr);

  return OkStatus();
}
1306 
RunSync(Options opts,Handle handle,gtl::ArraySlice<Tensor> args,std::vector<Tensor> * rets)1307 Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle,
1308                                            gtl::ArraySlice<Tensor> args,
1309                                            std::vector<Tensor>* rets) {
1310   Item* item = nullptr;
1311   std::unique_ptr<PrivateIntraProcessRendezvous> rendezvous;
1312   TF_RETURN_IF_ERROR(PrepareRunSync(handle, &opts, &item, &rendezvous));
1313   if (item == nullptr) {
1314     return parent_->RunSync(opts, handle, args, rets);
1315   }
1316 
1317   Executor::Args exec_args;
1318   const FunctionBody* fbody = GetFunctionBody(handle);
1319   FunctionCallFrame frame(fbody->arg_types, fbody->ret_types);
1320   TF_RETURN_IF_ERROR(frame.SetArgs(args));
1321   ExecutorArgsFromOptions(opts, &frame, &exec_args);
1322 
1323   TF_RETURN_IF_ERROR(item->exec->Run(exec_args));
1324   return frame.ConsumeRetvals(rets, opts.allow_dead_tensors);
1325 }
1326 
RunSync(Options opts,Handle handle,CallFrameInterface * call_frame)1327 Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle,
1328                                            CallFrameInterface* call_frame) {
1329   Item* item = nullptr;
1330   std::unique_ptr<PrivateIntraProcessRendezvous> rendezvous;
1331   TF_RETURN_IF_ERROR(PrepareRunSync(handle, &opts, &item, &rendezvous));
1332   if (item == nullptr) {
1333     return parent_->RunSync(opts, handle, call_frame);
1334   }
1335 
1336   Executor::Args exec_args;
1337   ExecutorArgsFromOptions(opts, call_frame, &exec_args);
1338   return item->exec->Run(exec_args);
1339 }
1340 
IsStateful(const string & func) const1341 bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) const {
1342   const OpDef* op_def;
1343   const Status s = base_lib_def_->LookUpOpDef(func, &op_def);
1344   return s.ok() && op_def->is_stateful();
1345 }
1346 
DebugString(Handle handle)1347 string FunctionLibraryRuntimeImpl::DebugString(Handle handle) {
1348   Item* item = nullptr;
1349   LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
1350   Status s = GetOrCreateItem(local_handle, &item);
1351   if (s.ok()) {
1352     if (item->graph) {
1353       return tensorflow::DebugString(item->graph.get());
1354     } else {
1355       return tensorflow::DebugString(item->func_graph->graph);
1356     }
1357   } else {
1358     return s.ToString();
1359   }
1360 }
1361 
Clone(std::unique_ptr<FunctionLibraryDefinition> * out_lib_def,std::unique_ptr<ProcessFunctionLibraryRuntime> * out_pflr,FunctionLibraryRuntime ** out_flr,bool skip_flib_def)1362 Status FunctionLibraryRuntimeImpl::Clone(
1363     std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
1364     std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
1365     FunctionLibraryRuntime** out_flr, bool skip_flib_def) {
1366   TF_RETURN_IF_ERROR(parent_->Clone(env_, graph_def_version_,
1367                                     optimizer_.options(), out_lib_def, out_pflr,
1368                                     skip_flib_def));
1369   *out_flr = (*out_pflr)->GetFLR(device_->name());
1370   if (*out_flr != nullptr) {
1371     return OkStatus();
1372   } else {
1373     return errors::Internal("Cloning FunctionLibraryRuntime failed.");
1374   }
1375 }
1376 
1377 namespace {
1378 
// Mutex-guarded process-wide holder for the custom kernel creator.
struct CustomCreatorSingleton {
  mutex mu;
  // Owned; replaced (and the old value destroyed) by Set().
  std::unique_ptr<CustomKernelCreator> custom_creator = nullptr;

  // Takes ownership of `cb`; passing nullptr clears the current creator.
  void Set(CustomKernelCreator* cb) {
    mutex_lock l(mu);
    custom_creator.reset(cb);
  }

  // Returns the current creator without transferring ownership; may be null.
  CustomKernelCreator* Get() {
    mutex_lock l(mu);
    return custom_creator.get();
  }
};
1393 
GetCustomCreatorSingleton()1394 CustomCreatorSingleton* GetCustomCreatorSingleton() {
1395   static CustomCreatorSingleton* ccs = new CustomCreatorSingleton;
1396   return ccs;
1397 }
1398 
1399 }  // namespace
1400 
GetDefaultCustomKernelCreator()1401 const CustomKernelCreator* GetDefaultCustomKernelCreator() {
1402   return GetCustomCreatorSingleton()->Get();
1403 }
1404 
RegisterDefaultCustomKernelCreator(CustomKernelCreator * c)1405 void RegisterDefaultCustomKernelCreator(CustomKernelCreator* c) {
1406   GetCustomCreatorSingleton()->Set(c);
1407 }
1408 
NewFunctionLibraryRuntime(const DeviceMgr * device_mgr,Env * env,const ConfigProto * config,Device * device,int graph_def_version,const FunctionLibraryDefinition * lib_def,thread::ThreadPool * thread_pool,const OptimizerOptions & optimizer_options,const SessionMetadata * session_metadata,ProcessFunctionLibraryRuntime * parent)1409 std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
1410     const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
1411     Device* device, int graph_def_version,
1412     const FunctionLibraryDefinition* lib_def, thread::ThreadPool* thread_pool,
1413     const OptimizerOptions& optimizer_options,
1414     const SessionMetadata* session_metadata,
1415     ProcessFunctionLibraryRuntime* parent) {
1416   return std::unique_ptr<FunctionLibraryRuntime>(new FunctionLibraryRuntimeImpl(
1417       device_mgr, env, config, device, graph_def_version, lib_def, thread_pool,
1418       optimizer_options, session_metadata, parent));
1419 }
1420 
// Helper that builds the function body of the symbolic gradient of a given
// function body.
class SymbolicGradientHelper {
 public:
  explicit SymbolicGradientHelper(const FunctionBody& f) : fbody_(&f) {}
  ~SymbolicGradientHelper() = default;

  // Returns a newly allocated FunctionBody computing the gradient of fbody_.
  std::unique_ptr<FunctionBody> Compute();

 private:
  // Not owned; must outlive this helper.
  const FunctionBody* fbody_;

  // Makes a copy of fbody_ in gbody.
  void Copy(FunctionBody* gbody);

  TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientHelper);
};
1436 
Copy(FunctionBody * gbody)1437 void SymbolicGradientHelper::Copy(FunctionBody* gbody) {
1438   const Graph& src = *(fbody_->graph);
1439   gbody->graph = new Graph(src.op_registry());
1440   Graph* dst = gbody->graph;
1441 
1442   std::vector<Node*> node_map(src.num_node_ids());
1443 
1444   // Copy just the fdef attributes (copy '_noinline' and other similar flags to
1445   // the gradient function body).
1446   *(gbody->fdef.mutable_attr()) = fbody_->fdef.attr();
1447 
1448   // Copy the nodes.
1449   node_map[src.source_node()->id()] = dst->source_node();
1450   node_map[src.sink_node()->id()] = dst->sink_node();
1451   for (Node* n : src.op_nodes()) {
1452     node_map[n->id()] = dst->CopyNode(n);
1453   }
1454 
1455   // Copy the edges.
1456   for (const Edge* e : src.edges()) {
1457     Node* src_copy = node_map[e->src()->id()];
1458     Node* dst_copy = node_map[e->dst()->id()];
1459     dst->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
1460   }
1461 
1462   // Save inputs in copied graph.
1463   CHECK_EQ(fbody_->arg_types.size(), fbody_->arg_nodes.size());
1464   gbody->arg_types = fbody_->arg_types;
1465   for (std::size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
1466     gbody->arg_nodes.push_back(node_map[fbody_->arg_nodes[i]->id()]);
1467   }
1468 
1469   // Save outputs in copied graph.
1470   CHECK_EQ(fbody_->ret_types.size(), fbody_->ret_nodes.size());
1471   gbody->ret_types = fbody_->ret_types;
1472   for (std::size_t i = 0; i < fbody_->ret_nodes.size(); ++i) {
1473     gbody->ret_nodes.push_back(node_map[fbody_->ret_nodes[i]->id()]);
1474   }
1475 }
1476 
Compute()1477 std::unique_ptr<FunctionBody> SymbolicGradientHelper::Compute() {
1478   FunctionBody* gbody = new FunctionBody;
1479   Copy(gbody);  // copy fbody_ into gbody.
1480 
1481   Graph* g = gbody->graph;
1482 
1483   const int num_y = static_cast<int>(gbody->ret_nodes.size());
1484 
1485   // Populate 'y_node_outputs_' with node function body outputs.
1486   // Populate 'y_grad_nodes' with initial gradient nodes for each return node
1487   // of the original function body (these will be 'arg' nodes in the function
1488   // gradient body).
1489   std::vector<NodeOut> y_node_outputs;
1490   y_node_outputs.reserve(num_y);
1491   std::vector<NodeOut> y_grad_node_outputs;
1492   y_grad_node_outputs.reserve(num_y);
1493   for (int i = 0; i < num_y; ++i) {
1494     Node* y = gbody->ret_nodes[i];
1495     y_node_outputs.push_back({y, 0});
1496     DCHECK_EQ(y->type_string(), kRetOp);
1497     const DataType dtype = y->input_type(0);
1498     const int index = static_cast<int>(gbody->arg_nodes.size());
1499     Node* dy = AddArg(g, dtype, index);
1500     gbody->arg_types.push_back(dtype);
1501     gbody->arg_nodes.push_back(dy);
1502     y_grad_node_outputs.push_back({dy, 0});
1503   }
1504 
1505   // Populate 'x_nodes' with function args (excluding 'y_grad_node_outputs').
1506   const size_t num_x = fbody_->arg_nodes.size();
1507   std::vector<NodeOut> x_node_outputs;
1508   x_node_outputs.reserve(num_x);
1509   for (size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
1510     x_node_outputs.push_back({gbody->arg_nodes[i], 0});
1511   }
1512 
1513   // Call AddSymbolicGradients which will add nodes to graph 'g' that
1514   // compute the function gradient (adding an entry in 'x_grad_node_outputs'
1515   // for each node in 'x_node_outputs').
1516   std::vector<NodeOut> x_grad_node_outputs;
1517   TF_CHECK_OK(AddSymbolicGradients(y_node_outputs, x_node_outputs,
1518                                    y_grad_node_outputs, &x_grad_node_outputs,
1519                                    g));
1520 
1521   // Remove the old return nodes from the function body.
1522   for (Node* n : gbody->ret_nodes) {
1523     g->RemoveNode(n);
1524   }
1525   gbody->ret_types = fbody_->arg_types;
1526   // TODO(apassos): use the right dtype for gradients of  resource variables
1527   for (int i = 0; i < gbody->ret_types.size(); ++i) {
1528     if (gbody->ret_types[i] == DT_RESOURCE) {
1529       gbody->ret_types[i] = DT_FLOAT;
1530     }
1531   }
1532   gbody->ret_nodes.clear();
1533   // Add new return nodes to the function gradient body for each node
1534   // in 'x_grad_nodes'.
1535   const int arg_types_size = static_cast<int>(fbody_->arg_types.size());
1536   for (int i = 0; i < arg_types_size; ++i) {
1537     Endpoint grad = {x_grad_node_outputs[i].node, x_grad_node_outputs[i].index};
1538     Node* ret = AddRet(g, grad, i);
1539     gbody->ret_nodes.push_back(ret);
1540   }
1541 
1542   return std::unique_ptr<FunctionBody>(gbody);
1543 }
1544 
SymbolicGradient(const FunctionBody & f)1545 std::unique_ptr<FunctionBody> SymbolicGradient(const FunctionBody& f) {
1546   return SymbolicGradientHelper(f).Compute();
1547 }
1548 
1549 }  // end namespace tensorflow
1550