/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/grpc_session.h"

#include <unordered_map>

#include "tensorflow/core/common_runtime/session_factory.h"
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/local_master.h"
#include "tensorflow/core/distributed_runtime/master_interface.h"
#include "tensorflow/core/distributed_runtime/request_id.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/protobuf/master.pb.h"

namespace tensorflow {

const char* const kSchemePrefix = "grpc://";
const size_t kSchemePrefixLength = strlen(kSchemePrefix);

GrpcSession::GrpcSession(const SessionOptions& options)
    : options_(options), current_graph_version_(-1) {}

GrpcSession::~GrpcSession() {}

/* static */
Status GrpcSession::Create(const SessionOptions& options,
                           std::unique_ptr<GrpcSession>* out_session) {
  std::unique_ptr<GrpcSession> session(new GrpcSession(options));
  std::unique_ptr<MasterInterface> master;
  // For testing, the client can disable the use of the local master
  // registry so that the RPC stack is exercised.
  if (!options.config.rpc_options().use_rpc_for_inprocess_master()) {
    master = LocalMaster::Lookup(options.target);
  }
  if (!master) {
    SharedGrpcChannelPtr master_channel;
    TF_RETURN_IF_ERROR(
        NewHostPortGrpcChannel(options.target.substr(kSchemePrefixLength),
                               &options.config.rpc_options(), &master_channel));
    master.reset(NewGrpcMaster(master_channel));
  } else {
    session->is_local_ = true;
  }
  session->SetRemoteMaster(std::move(master));
  *out_session = std::move(session);
  return OkStatus();
}
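
// Illustrative use of the factory method above (a minimal sketch; the target
// address is an assumption):
//
//   SessionOptions options;
//   options.target = "grpc://localhost:2222";
//   std::unique_ptr<GrpcSession> session;
//   TF_RETURN_IF_ERROR(GrpcSession::Create(options, &session));
//
// When the target names an in-process master, Create() bypasses gRPC via
// LocalMaster::Lookup() unless use_rpc_for_inprocess_master is set.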

namespace {
// Re-encodes constants represented with repeated proto fields into
// tensor_content, which is slightly better (fewer copies and lower peak
// memory usage) when used with RPC subsystems.
void ReEncodeConsts(GraphDef* gdef) {
  for (NodeDef& ndef : *(gdef->mutable_node())) {
    if (ndef.op() == "Const") {
      TensorProto* proto = nullptr;
      for (auto& attr : *ndef.mutable_attr()) {
        if (attr.first == "value") {
          proto = attr.second.mutable_tensor();
        }
      }
      if (proto != nullptr && proto->tensor_content().empty() &&
          proto->ByteSizeLong() > 64) {
        // If the constant is encoded with repeated proto fields and is
        // moderately large, we re-encode it in tensor_content as a Cord.
        // This is mildly helpful for reducing the peak memory usage on the
        // server side, where GraphDef/NodeDef are copied quite often.
        Tensor parsed(proto->dtype());
        if (parsed.FromProto(*proto)) {
          parsed.AsProtoTensorContent(proto);
        }
      }
    }
  }
}
}  // namespace
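
// To make the effect above concrete: a Const whose "value" attr arrived as
// repeated proto fields (e.g. a float tensor serialized into float_val) is
// round-tripped through a Tensor and re-serialized so the payload lives in
// the single flat tensor_content byte string instead. A minimal sketch of
// that round trip, standalone rather than inside the GraphDef walk:
//
//   TensorProto proto = ...;       // "value" attr of some Const node
//   Tensor parsed(proto.dtype());
//   if (parsed.FromProto(proto)) {
//     parsed.AsProtoTensorContent(&proto);  // repeated fields -> tensor_content
//   }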

void GrpcSession::SetHandleAndGraphVersion(string handle,
                                           int64_t graph_version) {
  mutex_lock l(mu_);
  handle_ = std::move(handle);
  current_graph_version_ = graph_version;
}

Status GrpcSession::Handle(string* out_handle) {
  mutex_lock l(mu_);
  if (handle_.empty()) {
    return errors::InvalidArgument("A session has not been created yet.");
  }
  *out_handle = handle_;
  return OkStatus();
}

Status GrpcSession::CreateImpl(CallOptions* call_options, GraphDef graph) {
  {
    mutex_lock l(mu_);
    if (!handle_.empty()) {
      return errors::InvalidArgument("A session is already created.");
    }
  }
  CreateSessionRequest req;
  *req.mutable_config() = options_.config;
  req.mutable_graph_def()->Swap(&graph);
  req.set_target(options_.target);
  ReEncodeConsts(req.mutable_graph_def());
  CreateSessionResponse resp;
  Status s = master_->CreateSession(call_options, &req, &resp);
  if (s.ok()) {
    SetHandleAndGraphVersion(resp.session_handle(), resp.graph_version());
  }
  return s;
}

Status GrpcSession::Create(const GraphDef& graph) {
  return Create(GraphDef(graph));
}

Status GrpcSession::Create(const RunOptions& run_options,
                           const GraphDef& graph) {
  return Create(run_options, GraphDef(graph));
}

Status GrpcSession::Create(GraphDef&& graph) {
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  return CreateImpl(&call_options, std::move(graph));
}

Status GrpcSession::Create(const RunOptions& run_options, GraphDef&& graph) {
  CallOptions call_options;
  call_options.SetTimeout(run_options.timeout_in_ms());
  return CreateImpl(&call_options, std::move(graph));
}
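
// The four Create() overloads above differ only in graph ownership and in
// where the RPC deadline comes from: the const& forms copy the GraphDef and
// forward to the && forms, and the RunOptions forms take their deadline from
// run_options.timeout_in_ms() rather than the session-wide
// operation_timeout_in_ms(). A hedged sketch (graph construction elided):
//
//   GraphDef graph = ...;   // assumed built elsewhere
//   RunOptions run_options;
//   run_options.set_timeout_in_ms(5000);  // 5s deadline for this call only
//   TF_RETURN_IF_ERROR(session->Create(run_options, std::move(graph)));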

Status GrpcSession::ExtendImpl(CallOptions* call_options, GraphDef graph) {
  bool handle_is_empty;
  {
    mutex_lock l(mu_);
    handle_is_empty = handle_.empty();
  }
  if (handle_is_empty) {
    // Session was uninitialized, so simply initialize the session with
    // 'graph'.
    return Create(std::move(graph));
  }
  mutex_lock l(mu_);
  ExtendSessionRequest req;
  req.set_session_handle(handle_);
  req.mutable_graph_def()->Swap(&graph);
  req.set_current_graph_version(current_graph_version_);
  ExtendSessionResponse resp;
  Status s = master_->ExtendSession(call_options, &req, &resp);
  if (s.ok()) {
    current_graph_version_ = resp.new_graph_version();
  }
  return s;
}

Status GrpcSession::Extend(const GraphDef& graph) {
  return Extend(GraphDef(graph));
}

Status GrpcSession::Extend(const RunOptions& run_options,
                           const GraphDef& graph) {
  return Extend(run_options, GraphDef(graph));
}

Status GrpcSession::Extend(GraphDef&& graph) {
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  return ExtendImpl(&call_options, std::move(graph));
}

Status GrpcSession::Extend(const RunOptions& run_options, GraphDef&& graph) {
  CallOptions call_options;
  call_options.SetTimeout(run_options.timeout_in_ms());
  return ExtendImpl(&call_options, std::move(graph));
}
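
// Extend() appends nodes to the graph already registered with the master.
// current_graph_version_ acts as an optimistic-concurrency token: the master
// rejects the request if the client's version is stale, and a successful
// extension records the new version. A hedged sketch (node contents are an
// assumption):
//
//   GraphDef more_nodes = ...;  // additional nodes to append
//   TF_RETURN_IF_ERROR(session->Extend(std::move(more_nodes)));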

Status GrpcSession::RunHelper(
    const RunOptions& run_options,
    const std::vector<std::pair<string, Tensor>>& inputs,
    const std::vector<string>& output_tensor_names,
    const std::vector<string>& target_node_names, std::vector<Tensor>* outputs,
    RunMetadata* run_metadata, const string& prun_handle) {
  // Convert to proto
  std::unique_ptr<MutableRunStepRequestWrapper> req(
      master_->CreateRunStepRequest());
  std::unique_ptr<MutableRunStepResponseWrapper> resp(
      master_->CreateRunStepResponse());

  *req->mutable_options() = run_options;

  if (run_options.timeout_in_ms() == 0) {
    req->mutable_options()->set_timeout_in_ms(
        options_.config.operation_timeout_in_ms());
  }

  if (!prun_handle.empty()) {
    req->set_partial_run_handle(prun_handle);
  }

  for (const auto& it : inputs) {
    req->add_feed(it.first, it.second);
  }

  // Support long error messages by storing the error code in the response
  // body.
  req->set_store_errors_in_response_body(true);

  // Build an index from fetch tensor name to first index in
  // output_tensor_names.
  std::unordered_map<string, int> output_name_to_offset;
  for (int i = 0, end = output_tensor_names.size(); i < end; ++i) {
    const string& name = output_tensor_names[i];
    if (output_name_to_offset.insert(std::make_pair(name, i)).second) {
      req->add_fetch(name);
    }
  }
  for (const string& target : target_node_names) {
    req->add_target(target);
  }

  CallOptions call_options;
  call_options.SetTimeout(req->options().timeout_in_ms());
  TF_RETURN_IF_ERROR(RunProto(&call_options, req.get(), resp.get()));

  // Look for an extended error returned in the response body.
  if (resp->status_code() != error::Code::OK) {
    return resp->status();
  }

  if (!output_tensor_names.empty()) {
    outputs->resize(output_tensor_names.size());
  }

  // Convert response back to Tensors in the correct order.
  for (size_t i = 0; i < resp->num_tensors(); ++i) {
    auto fetch_it = output_name_to_offset.find(resp->tensor_name(i));
    if (fetch_it == output_name_to_offset.end()) {
      return errors::Internal("Received response for unrequested fetch: ",
                              resp->tensor_name(i));
    }

    Tensor output;
    TF_RETURN_IF_ERROR(resp->TensorValue(i, &output));
    (*outputs)[fetch_it->second] = output;
  }
  // In the unlikely event that output_tensor_names contains duplicates, fill
  // in the duplicate values.
  if (output_name_to_offset.size() != output_tensor_names.size()) {
    for (int i = 0, end = output_tensor_names.size(); i < end; ++i) {
      const string& name = output_tensor_names[i];
      int offset = output_name_to_offset[name];
      if (offset != i) {
        (*outputs)[i] = (*outputs)[offset];
      }
    }
  }

  if (run_metadata) {
    run_metadata->Swap(resp->mutable_metadata());
  }

  return OkStatus();
}
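
// Worked example of the duplicate-fetch path above: for output_tensor_names
// {"a:0", "b:0", "a:0"}, only {"a:0", "b:0"} are sent as fetches;
// output_name_to_offset maps "a:0" -> 0 and "b:0" -> 1, so after the response
// is unpacked the backfill loop copies (*outputs)[0] into (*outputs)[2].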

Status GrpcSession::Run(const RunOptions& run_options,
                        const std::vector<std::pair<string, Tensor>>& inputs,
                        const std::vector<string>& output_tensor_names,
                        const std::vector<string>& target_node_names,
                        std::vector<Tensor>* outputs,
                        RunMetadata* run_metadata) {
  return RunHelper(run_options, inputs, output_tensor_names, target_node_names,
                   outputs, run_metadata, /* prun_handle */ "");
}

Status GrpcSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
                        const std::vector<string>& output_tensor_names,
                        const std::vector<string>& target_node_names,
                        std::vector<Tensor>* outputs) {
  RunOptions run_options;
  run_options.set_timeout_in_ms(options_.config.operation_timeout_in_ms());
  return Run(run_options, inputs, output_tensor_names, target_node_names,
             outputs, nullptr);
}
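
// Typical call through the public Session interface (a sketch; the feed,
// fetch, and target names are assumptions):
//
//   std::vector<Tensor> outputs;
//   TF_RETURN_IF_ERROR(session->Run({{"x", x_tensor}},  // feeds
//                                   {"y:0"},            // fetches
//                                   {"train_op"},       // targets
//                                   &outputs));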

Status GrpcSession::RunProto(CallOptions* call_options,
                             MutableRunStepRequestWrapper* req,
                             MutableRunStepResponseWrapper* resp) {
  string handle;
  TF_RETURN_IF_ERROR(Handle(&handle));
  req->set_session_handle(handle);
  return master_->RunStep(call_options, req, resp);
}

Status GrpcSession::PRunSetup(const std::vector<string>& input_names,
                              const std::vector<string>& output_names,
                              const std::vector<string>& target_nodes,
                              string* handle) {
  // Convert to proto
  PartialRunSetupRequest req;
  PartialRunSetupResponse resp;
  CallOptions call_options;
  TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
  for (const string& feed : input_names) {
    req.add_feed(feed);
  }
  for (const string& fetch : output_names) {
    req.add_fetch(fetch);
  }
  for (const string& target : target_nodes) {
    req.add_target(target);
  }
  if (!is_local_) req.set_request_id(GetUniqueRequestId());
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  TF_RETURN_IF_ERROR(master_->PartialRunSetup(&call_options, &req, &resp));
  *handle = resp.partial_run_handle();
  return OkStatus();
}

Status GrpcSession::PRun(const string& handle,
                         const std::vector<std::pair<string, Tensor>>& inputs,
                         const std::vector<string>& output_names,
                         std::vector<Tensor>* outputs) {
  RunOptions run_options;
  run_options.set_timeout_in_ms(options_.config.operation_timeout_in_ms());
  return RunHelper(run_options, inputs, output_names, /* targets */ {},
                   outputs, /* run_metadata */ nullptr, handle);
}
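
// Partial runs split one logical step across several RPCs: PRunSetup()
// declares every feed, fetch, and target up front and returns a handle, after
// which each PRun() supplies a subset of the feeds and retrieves a subset of
// the fetches. A hedged sketch (tensor names are assumptions):
//
//   string prun_handle;
//   TF_RETURN_IF_ERROR(
//       session->PRunSetup({"x"}, {"y:0", "z:0"}, {}, &prun_handle));
//   std::vector<Tensor> y;
//   TF_RETURN_IF_ERROR(
//       session->PRun(prun_handle, {{"x", x_tensor}}, {"y:0"}, &y));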

Status GrpcSession::Close() {
  CloseSessionRequest req;
  {
    mutex_lock l(mu_);
    if (handle_.empty()) {
      return OkStatus();
    }
    req.set_session_handle(handle_);
    handle_.clear();
  }
  CloseSessionResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  return master_->CloseSession(&call_options, &req, &resp);
}

Status GrpcSession::ListDevices(std::vector<DeviceAttributes>* response) {
  ListDevicesRequest req;
  {
    mutex_lock l(mu_);
    req.set_session_handle(handle_);
  }
  if (req.session_handle().empty()) {
    LOG(WARNING) << "GrpcSession::ListDevices will initialize the session with "
                    "an empty graph and other defaults because the session has "
                    "not yet been created.";
    GraphDef graph_def;
    TF_RETURN_IF_ERROR(Create(graph_def));
    {
      mutex_lock l(mu_);
      req.set_session_handle(handle_);
    }
  }
  ListDevicesResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  Status s = master_->ListDevices(&call_options, &req, &resp);
  if (!s.ok()) {
    LOG(ERROR) << "Could not list devices: " << s;
    return s;
  }

  response->clear();
  response->reserve(resp.local_device_size() + resp.remote_device_size());
  for (const auto& device_attr : resp.local_device()) {
    response->emplace_back(device_attr);
  }
  for (const auto& device_attr : resp.remote_device()) {
    response->emplace_back(device_attr);
  }
  return OkStatus();
}

void GrpcSession::SetRemoteMaster(std::unique_ptr<MasterInterface> master) {
  master_ = std::move(master);
}

// Static method.
Status GrpcSession::Reset(const SessionOptions& options,
                          const std::vector<string>& containers) {
  SharedGrpcChannelPtr master_channel;
  TF_RETURN_IF_ERROR(
      NewHostPortGrpcChannel(options.target.substr(kSchemePrefixLength),
                             /*rpc_options=*/nullptr, &master_channel));
  auto master = NewGrpcMaster(master_channel);
  ResetRequest req;
  req.mutable_container()->Reserve(containers.size());
  for (const auto& c : containers) req.add_container(c);
  ResetResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options.config.operation_timeout_in_ms());
  Status ret = master->Reset(&call_options, &req, &resp);
  delete master;
  return ret;
}
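
// Reset() is deliberately sessionless: it opens a fresh channel to the master
// named by options.target and asks it to clear the listed resource
// containers, independent of any live GrpcSession. A sketch (the address and
// container name are assumptions):
//
//   SessionOptions options;
//   options.target = "grpc://localhost:2222";
//   TF_RETURN_IF_ERROR(GrpcSession::Reset(options, {"my_container"}));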

Status GrpcSession::MakeCallable(const CallableOptions& callable_options,
                                 CallableHandle* out_handle) {
  MakeCallableRequest req;
  TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
  *req.mutable_options() = callable_options;
  if (!is_local_) req.set_request_id(GetUniqueRequestId());
  MakeCallableResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  TF_RETURN_IF_ERROR(master_->MakeCallable(&call_options, &req, &resp));
  *out_handle = resp.handle();
  return OkStatus();
}

Status GrpcSession::RunCallable(CallableHandle handle,
                                const std::vector<Tensor>& feed_tensors,
                                std::vector<Tensor>* fetch_tensors,
                                RunMetadata* run_metadata) {
  RunCallableRequest req;
  TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
  req.set_handle(handle);
  if (!is_local_) req.set_request_id(GetUniqueRequestId());
  for (const Tensor& feed : feed_tensors) {
    feed.AsProtoTensorContent(req.mutable_feed()->Add());
  }

  RunCallableResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  TF_RETURN_IF_ERROR(master_->RunCallable(&call_options, &req, &resp));
  for (const TensorProto& fetch : resp.fetch()) {
    Tensor fetch_tensor;
    if (!fetch_tensor.FromProto(cpu_allocator(), fetch)) {
      return errors::Internal(
          "Could not parse fetched tensor data in response from master.");
    }
    fetch_tensors->push_back(std::move(fetch_tensor));
  }
  return OkStatus();
}

Status GrpcSession::ReleaseCallable(CallableHandle handle) {
  ReleaseCallableRequest req;
  TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
  req.set_handle(handle);
  ReleaseCallableResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  return master_->ReleaseCallable(&call_options, &req, &resp);
}
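
// Callables bundle a validated feed/fetch signature into a handle so that
// repeated steps skip per-call name resolution. A hedged sketch of the full
// life cycle (the tensor names are assumptions):
//
//   CallableOptions opts;
//   opts.add_feed("x:0");
//   opts.add_fetch("y:0");
//   Session::CallableHandle handle;
//   TF_RETURN_IF_ERROR(session->MakeCallable(opts, &handle));
//   std::vector<Tensor> fetched;
//   TF_RETURN_IF_ERROR(
//       session->RunCallable(handle, {x_tensor}, &fetched, nullptr));
//   TF_RETURN_IF_ERROR(session->ReleaseCallable(handle));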

class GrpcSessionFactory : public SessionFactory {
 public:
  bool AcceptsOptions(const SessionOptions& options) override {
    return absl::StartsWith(options.target, kSchemePrefix);
  }

  Status NewSession(const SessionOptions& options,
                    Session** out_session) override {
    std::unique_ptr<GrpcSession> session;
    TF_RETURN_IF_ERROR(GrpcSession::Create(options, &session));
    *out_session = session.release();
    return OkStatus();
  }

  // Invokes the session-specific static method to reset containers.
  Status Reset(const SessionOptions& options,
               const std::vector<string>& containers) override {
    return GrpcSession::Reset(options, containers);
  }
};
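
// The registrar below runs at load time: constructing the static object adds
// GrpcSessionFactory to the global factory table under "GRPC_SESSION", so a
// plain NewSession() call with a "grpc://..." target resolves to a
// GrpcSession. A sketch of the client-side effect (the address is an
// assumption):
//
//   SessionOptions options;
//   options.target = "grpc://worker0:2222";
//   std::unique_ptr<Session> session(NewSession(options));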
class GrpcSessionRegistrar {
 public:
  GrpcSessionRegistrar() {
    SessionFactory::Register("GRPC_SESSION", new GrpcSessionFactory());
  }
};
static GrpcSessionRegistrar registrar;

}  // namespace tensorflow