/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/cluster_function_library_runtime.h"

#include <map>

#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/protobuf/named_tensor.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"

namespace tensorflow {

/* static */
Status ClusterFunctionLibraryRuntime::ConstructFunctionGraph(
    const OpDef& sig, AttrSlice attrs,
    const FunctionLibraryRuntime::InstantiateOptions& options,
    const FunctionLibraryDefinition& flib_def, GraphDef* gdef,
    std::vector<string>* send_keys, std::vector<string>* recv_keys) {
  const string& target = options.target;
  const string& func_name = sig.name();
  const FunctionDef* func_def = flib_def.Find(sig.name());
  if (func_def == nullptr) {
    return errors::InvalidArgument("Function ", func_name,
                                   " not found in flib_def.");
  }

  // Build a smaller flib_def containing only the functions used by the given
  // function, plus that function itself.
  FunctionLibraryDefinition pruned_flib_def =
      flib_def.ReachableDefinitions(*func_def);
  TF_RETURN_IF_ERROR(pruned_flib_def.CopyFunctionDefFrom(func_name, flib_def));

  Graph g(pruned_flib_def);

  std::vector<Node*> input_nodes;
  input_nodes.reserve(sig.input_arg_size());

  // Construct recv nodes for each input argument.
  int i = 0;
  for (const auto& in : sig.input_arg()) {
    // Resolve the input type.
    bool is_type_list;
    DataTypeVector dtypes;
    TF_RETURN_IF_ERROR(ArgNumType(attrs, in, &is_type_list, &dtypes));
    // TODO(rohanj): Handle list and variadic number of attrs. Here and below.
    if (is_type_list || dtypes.size() > 1) {
      return errors::Unimplemented("Input arg: ", in.name(),
                                   " has a list type or variadic number of "
                                   "attrs. Currently unsupported.");
    }

    auto input_node_builder =
        NodeDefBuilder(strings::StrCat("_recv_", in.name(), "_", i), "_Recv")
            .Attr("tensor_type", dtypes[0])
            .Attr("tensor_name", in.name())
            .Attr("send_device", target)
            .Attr("recv_device", target)
            .Attr("send_device_incarnation", 1)
            .Attr("client_terminated", true)
            .Device(target);

    Node* input_node;
    TF_RETURN_IF_ERROR(
        NodeBuilder(input_node_builder).Finalize(&g, &input_node));
    input_nodes.push_back(input_node);

    // src_incarnation = 1 works because the transfer is across the same device.
    // TODO(rohanj): Find the src_incarnation for the remote device and set it.
    const string& key = Rendezvous::CreateKey(
        target, 1 /* src_incarnation */, target, in.name(), FrameAndIter(0, 0));
    send_keys->push_back(key);
    ++i;
  }

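  // Add a node that calls the function on the target device and wire each
  // _Recv node into the corresponding function input.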
  NodeDef function_node_def;
  function_node_def.set_name(func_name);
  function_node_def.set_op(func_name);
  i = 0;
  function_node_def.set_device(target);
  for (const auto& p : attrs) {
    (*function_node_def.mutable_attr())[p.first] = p.second;
  }
  TF_ASSIGN_OR_RETURN(Node * function_node,
                      g.AddNode(std::move(function_node_def)));
  for (size_t i = 0; i < input_nodes.size(); ++i) {
    g.AddEdge(input_nodes[i], 0, function_node, i);
  }

  // Construct output nodes for each output.
  i = 0;
  for (const auto& out : sig.output_arg()) {
    // Resolve the output type.
    bool is_type_list;
    DataTypeVector dtypes;
    TF_RETURN_IF_ERROR(ArgNumType(attrs, out, &is_type_list, &dtypes));
    // TODO(rohanj): Handle list and variadic number of attrs. Here and below.
    if (is_type_list || dtypes.size() > 1) {
      return errors::Unimplemented("Output arg: ", out.name(),
                                   " has a list type or variadic number of "
                                   "attrs. Currently unsupported.");
    }

    auto output_node_builder =
        NodeDefBuilder(strings::StrCat("_send_", out.name(), "_", i), "_Send")
            .Input(func_name, i, dtypes[0])
            .Attr("tensor_name", out.name())
            .Attr("send_device", target)
            .Attr("recv_device", target)
            .Attr("send_device_incarnation", 1)
            .Attr("client_terminated", true)
            .Device(target);

    Node* output_node;
    TF_RETURN_IF_ERROR(
        NodeBuilder(output_node_builder).Finalize(&g, &output_node));

    g.AddEdge(function_node, i, output_node, 0);

    const string& key =
        Rendezvous::CreateKey(target, 1 /* src_incarnation */, target,
                              out.name(), FrameAndIter(0, 0));
    recv_keys->push_back(key);
    ++i;
  }

  // Inline function node into the graph.
  InlineFunctionBodyOptions inline_options;
  inline_options.inlined_function_body_placer =
      InlinedFunctionBodyPlacer::SingleDevice();
  // When the remote call is a partition of a multi-device function, and the
  // Send/Recv nodes depend on the frame names in the original graph, we must
  // retain the original frame names. Since the graph contains a single function
  // call, we do not need to add a unique prefix to frame names inside the
  // inlined graph.
  inline_options.uniquify_frame_names = false;
  std::unique_ptr<FunctionBody> function_body;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*func_def, attrs, &pruned_flib_def,
                                             &function_body));
  TF_RETURN_IF_ERROR(InlineFunctionBody(pruned_flib_def, &g, function_node,
                                        function_body.get(), inline_options));

  g.ToGraphDef(gdef);

  // Since we have inlined `function_node`, we can prune its function definition
  // from the library.
  *(gdef->mutable_library()) = flib_def.ReachableDefinitions(*gdef).ToProto();

  return OkStatus();
}

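// Release the workers acquired from the worker cache during Instantiate.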
ClusterFunctionLibraryRuntime::~ClusterFunctionLibraryRuntime() {
  for (auto& function_data : function_data_) {
    worker_session_->worker_cache()->ReleaseWorker(function_data.target,
                                                   function_data.wi);
  }
}

void ClusterFunctionLibraryRuntime::Instantiate(
    const string& function_name, const FunctionLibraryDefinition& lib_def,
    AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options,
    FunctionLibraryRuntime::LocalHandle* handle,
    FunctionLibraryRuntime::DoneCallback done) {
  auto target = options.target;
  VLOG(1) << "CFLR::Instantiate: " << function_name << " on " << target
          << " (this: " << this << ")";
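  // Resolve the target worker through the session's worker cache.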
  std::shared_ptr<WorkerCacheInterface> worker_cache =
      worker_session_->GetSharedWorkerCache();
  WorkerInterface* wi = worker_cache->GetOrCreateWorker(target);

  if (wi == nullptr) {
    std::vector<string> workers;
    worker_session_->worker_cache()->ListWorkers(&workers);
    done(errors::InvalidArgument(
        "Could not find worker with target: ", target,
        " Available workers: ", absl::StrJoin(workers, ", ")));
    return;
  }

  // Make RPC and obtain a graph handle.
  GraphDef gdef;
  auto* send_keys = new std::vector<string>;
  auto* recv_keys = new std::vector<string>;
  auto construct_graph_fn = [&](const FunctionLibraryDefinition* lib_def) {
    const FunctionDef* fdef = lib_def->Find(function_name);
    if (fdef == nullptr) {
      return errors::InvalidArgument("Function ", function_name,
                                     " not found in the function library.");
    }
    const OpDef& sig = fdef->signature();
    TF_RETURN_IF_ERROR(ConstructFunctionGraph(sig, attrs, options, *lib_def,
                                              &gdef, send_keys, recv_keys));
    return OkStatus();
  };
  Status s;
  if (options.lib_def) {
    s = construct_graph_fn(options.lib_def);
  } else {
    s = construct_graph_fn(&lib_def);
  }
  if (!s.ok()) {
    done(s);
    return;
  }

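  // Register the constructed graph with the remote worker. On success, the
  // graph handle in the response identifies this instantiation for later
  // RunGraph calls.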
  auto* req = new RegisterGraphRequest;
  req->set_session_handle(worker_session_->session_name());
  req->set_create_worker_session_called(create_worker_session_called_);
  *req->mutable_graph_def() = std::move(gdef);
  StripDefaultAttributes(*OpRegistry::Global(),
                         req->mutable_graph_def()->mutable_node());
  req->mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_do_function_inlining(true);
  auto* resp = new RegisterGraphResponse;

  wi->RegisterGraphAsync(
      req, resp,
      [this, handle, req, resp, worker_cache, wi, function_name, target,
       send_keys, recv_keys, done](const Status& status) {
        if (status.ok()) {
          mutex_lock l(mu_);
          *handle = function_data_.size();
          function_data_.push_back(FunctionData(resp->graph_handle(), target,
                                                worker_cache, wi, *send_keys,
                                                *recv_keys));
          VLOG(1) << "CFLR::Instantiate: [Success] " << function_name << " on "
                  << target << " (this: " << this << ")"
                  << " with handle: " << *handle;
        }
        done(status);
        delete recv_keys;
        delete send_keys;
        delete req;
        delete resp;
      });
}

void ClusterFunctionLibraryRuntime::Run(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::LocalHandle handle, gtl::ArraySlice<Tensor> args,
    std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) {
  FunctionData* function_data = nullptr;
  {
    mutex_lock l(mu_);
    CHECK_LT(handle, function_data_.size());
    function_data = &function_data_[handle];
  }

  WorkerInterface* wi = function_data->wi;

  if (wi == nullptr) {
    done(errors::Internal("Could not find worker"));
    return;
  }

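  // Build the RunGraph request: feed each argument tensor under its send key
  // and request the tensors named by the recv keys.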
  RunGraphRequest* req = new RunGraphRequest;
  req->set_session_handle(worker_session_->session_name());
  req->set_create_worker_session_called(create_worker_session_called_);
  req->set_graph_handle(function_data->graph_handle);
  req->set_step_id(opts.step_id);
  int i = 0;
  for (const auto& send_key : function_data->send_keys) {
    NamedTensorProto* send = req->add_send();
    send->set_name(send_key);
    args[i].AsProtoTensorContent(send->mutable_tensor());
    i++;
  }
  const std::vector<string>& recv_keys = function_data->recv_keys;
  for (const auto& recv_key : recv_keys) {
    req->add_recv_key(recv_key);
  }

  RunGraphResponse* resp = new RunGraphResponse();
  CallOptions* call_options = new CallOptions();
  wi->RunGraphAsync(
      call_options, req, resp,
      [call_options, req, resp, rets, recv_keys, done](const Status& status) {
        Status* local_status = new Status(status);
        auto cleanup =
            gtl::MakeCleanup([call_options, req, resp, local_status, done] {
              done(*local_status);
              delete call_options;
              delete req;
              delete resp;
              delete local_status;
            });
        if (!local_status->ok()) {
          return;
        }
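        // Index the returned tensors by name so the outputs can be collected
        // in recv_keys order.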
        std::map<string, TensorProto*> mapped_recvs;
        for (auto& recv : *resp->mutable_recv()) {
          mapped_recvs[recv.name()] = recv.mutable_tensor();
        }

        for (const auto& recv_key : recv_keys) {
          TensorProto* tp = mapped_recvs[recv_key];
          if (tp == nullptr) {
            local_status->Update(
                errors::Internal("Could not find key: ", recv_key));
            return;
          }
          Tensor t;
          if (t.FromProto(*tp)) {
            rets->push_back(t);
          } else {
            local_status->Update(errors::Internal(
                "Could not convert tensor proto: ", tp->DebugString()));
            return;
          }
        }
      });
}

void ClusterFunctionLibraryRuntime::Run(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::LocalHandle handle,
    gtl::ArraySlice<FunctionArg> args, std::vector<FunctionRet>* rets,
    FunctionLibraryRuntime::DoneCallback done) {
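  // Only plain Tensor arguments are supported; eager::RemoteTensorHandle
  // arguments are rejected below.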
  std::vector<Tensor> tensors;
  for (const auto& arg : args) {
    if (arg.index() == 0) {
      tensors.push_back(absl::get<Tensor>(arg));
    } else {
      done(
          errors::Internal("ClusterFunctionLibraryRuntime doesn't support "
                           "eager::RemoteTensorHandle."));
      return;
    }
  }
  std::vector<Tensor>* ret_tensors = new std::vector<Tensor>;
  return Run(opts, handle, tensors, ret_tensors,
             [rets, ret_tensors, done = std::move(done)](const Status& s) {
               if (s.ok()) {
                 for (const auto& t : *ret_tensors) {
                   rets->push_back(t);
                 }
               }
               delete ret_tensors;
               done(s);
             });
}

void ClusterFunctionLibraryRuntime::CleanUp(
    uint64 step_id, FunctionLibraryRuntime::LocalHandle handle,
    FunctionLibraryRuntime::DoneCallback done) {
  FunctionData* function_data = nullptr;
  {
    mutex_lock l(mu_);
    DCHECK_LT(handle, function_data_.size());
    function_data = &function_data_[handle];
  }

  WorkerInterface* wi = function_data->wi;

  if (wi == nullptr) {
    done(errors::Internal("Could not find worker"));
    return;
  }
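  // Ask the remote worker to release any per-step state recorded for this
  // step id.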
  CleanupGraphRequest* cleanup_req = new CleanupGraphRequest;
  cleanup_req->set_step_id(step_id);
  CleanupGraphResponse* cleanup_resp = new CleanupGraphResponse;
  wi->CleanupGraphAsync(
      cleanup_req, cleanup_resp,
      [cleanup_req, cleanup_resp, done](const Status& cleanup_status) {
        done(cleanup_status);
        delete cleanup_req;
        delete cleanup_resp;
      });
}

}  // namespace tensorflow