/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/grpc_session.h"

#include <unordered_map>

#include "tensorflow/core/common_runtime/session_factory.h"
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/local_master.h"
#include "tensorflow/core/distributed_runtime/master_interface.h"
#include "tensorflow/core/distributed_runtime/request_id.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/protobuf/master.pb.h"

namespace tensorflow {

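// Session targets handled here are of the form "grpc://host:port"; the scheme
// prefix is stripped before the channel to the master is created. A
// GrpcSession is usually obtained through the generic factory registered at
// the bottom of this file, e.g. (sketch; the address is illustrative):
//
//   SessionOptions options;
//   options.target = "grpc://localhost:2222";
//   std::unique_ptr<Session> session(NewSession(options));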
const char* const kSchemePrefix = "grpc://";
const size_t kSchemePrefixLength = strlen(kSchemePrefix);

GrpcSession::GrpcSession(const SessionOptions& options)
    : options_(options), current_graph_version_(-1) {}

GrpcSession::~GrpcSession() {}

/* static */
Status GrpcSession::Create(const SessionOptions& options,
                           std::unique_ptr<GrpcSession>* out_session) {
  std::unique_ptr<GrpcSession> session(new GrpcSession(options));
  std::unique_ptr<MasterInterface> master;
  // For testing, we enable the client to disable the use of the local
  // master registry, so that the RPC stack is exercised.
  if (!options.config.rpc_options().use_rpc_for_inprocess_master()) {
    master = LocalMaster::Lookup(options.target);
  }
  if (!master) {
    SharedGrpcChannelPtr master_channel;
    TF_RETURN_IF_ERROR(
        NewHostPortGrpcChannel(options.target.substr(kSchemePrefixLength),
                               &options.config.rpc_options(), &master_channel));
    master.reset(NewGrpcMaster(master_channel));
  } else {
    session->is_local_ = true;
  }
  session->SetRemoteMaster(std::move(master));
  *out_session = std::move(session);
  return OkStatus();
}

namespace {
// Re-encodes constants represented in TensorProto repeated fields into
// tensor_content, which is slightly better (fewer copies and lower peak
// memory usage) when used with RPC subsystems.
void ReEncodeConsts(GraphDef* gdef) {
  for (NodeDef& ndef : *(gdef->mutable_node())) {
    if (ndef.op() == "Const") {
      TensorProto* proto = nullptr;
      for (auto& attr : *ndef.mutable_attr()) {
        if (attr.first == "value") {
          proto = attr.second.mutable_tensor();
        }
      }
      if (proto != nullptr && proto->tensor_content().empty() &&
          proto->ByteSizeLong() > 64) {
        // If the constant is encoded with repeated proto fields and
        // it is moderately large, we re-encode it in tensor_content as
        // a Cord. This is mildly helpful for reducing the peak memory
        // usage on the server side, where GraphDef/NodeDef are copied
        // quite often.
        Tensor parsed(proto->dtype());
        if (parsed.FromProto(*proto)) {
          parsed.AsProtoTensorContent(proto);
        }
      }
    }
  }
}
}  // namespace

void GrpcSession::SetHandleAndGraphVersion(string handle,
                                           int64_t graph_version) {
  mutex_lock l(mu_);
  handle_ = std::move(handle);
  current_graph_version_ = graph_version;
}

Status GrpcSession::Handle(string* out_handle) {
  mutex_lock l(mu_);
  if (handle_.empty()) {
    return errors::InvalidArgument("A session has not been created yet.");
  }
  *out_handle = handle_;
  return OkStatus();
}

Status GrpcSession::CreateImpl(CallOptions* call_options, GraphDef graph) {
  {
    mutex_lock l(mu_);
    if (!handle_.empty()) {
      return errors::InvalidArgument("A session is already alive.");
    }
  }
  CreateSessionRequest req;
  *req.mutable_config() = options_.config;
  req.mutable_graph_def()->Swap(&graph);
  req.set_target(options_.target);
  ReEncodeConsts(req.mutable_graph_def());
  CreateSessionResponse resp;
  Status s = master_->CreateSession(call_options, &req, &resp);
  if (s.ok()) {
    SetHandleAndGraphVersion(resp.session_handle(), resp.graph_version());
  }
  return s;
}

Status GrpcSession::Create(const GraphDef& graph) {
  return Create(GraphDef(graph));
}

Status GrpcSession::Create(const RunOptions& run_options,
                           const GraphDef& graph) {
  return Create(run_options, GraphDef(graph));
}

Status GrpcSession::Create(GraphDef&& graph) {
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  return CreateImpl(&call_options, std::move(graph));
}

Status GrpcSession::Create(const RunOptions& run_options, GraphDef&& graph) {
  CallOptions call_options;
  call_options.SetTimeout(run_options.timeout_in_ms());
  return CreateImpl(&call_options, std::move(graph));
}

Status GrpcSession::ExtendImpl(CallOptions* call_options, GraphDef graph) {
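  // Read the handle under the lock, but release the lock before possibly
  // calling Create() below: Create() -> CreateImpl() acquires mu_ itself.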
  bool handle_is_empty;
  {
    mutex_lock l(mu_);
    handle_is_empty = handle_.empty();
  }
  if (handle_is_empty) {
    // Session was uninitialized, so simply initialize the session with
    // 'graph'.
    return Create(std::move(graph));
  }
  mutex_lock l(mu_);
  ExtendSessionRequest req;
  req.set_session_handle(handle_);
  req.mutable_graph_def()->Swap(&graph);
  req.set_current_graph_version(current_graph_version_);
  ExtendSessionResponse resp;
  Status s = master_->ExtendSession(call_options, &req, &resp);
  if (s.ok()) {
    current_graph_version_ = resp.new_graph_version();
  }
  return s;
}

Status GrpcSession::Extend(const GraphDef& graph) {
  return Extend(GraphDef(graph));
}

Status GrpcSession::Extend(const RunOptions& run_options,
                           const GraphDef& graph) {
  return Extend(run_options, GraphDef(graph));
}

Status GrpcSession::Extend(GraphDef&& graph) {
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  return ExtendImpl(&call_options, std::move(graph));
}

Status GrpcSession::Extend(const RunOptions& run_options, GraphDef&& graph) {
  CallOptions call_options;
  call_options.SetTimeout(run_options.timeout_in_ms());
  return ExtendImpl(&call_options, std::move(graph));
}

Status GrpcSession::RunHelper(
    const RunOptions& run_options,
    const std::vector<std::pair<string, Tensor>>& inputs,
    const std::vector<string>& output_tensor_names,
    const std::vector<string>& target_node_names, std::vector<Tensor>* outputs,
    RunMetadata* run_metadata, const string& prun_handle) {
  // Convert to proto
  std::unique_ptr<MutableRunStepRequestWrapper> req(
      master_->CreateRunStepRequest());
  std::unique_ptr<MutableRunStepResponseWrapper> resp(
      master_->CreateRunStepResponse());

  *req->mutable_options() = run_options;

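  // A timeout of zero in RunOptions means "unset": fall back to the
  // session-wide operation timeout configured in SessionOptions.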
  if (run_options.timeout_in_ms() == 0) {
    req->mutable_options()->set_timeout_in_ms(
        options_.config.operation_timeout_in_ms());
  }

  if (!prun_handle.empty()) {
    req->set_partial_run_handle(prun_handle);
  }

  for (const auto& it : inputs) {
    req->add_feed(it.first, it.second);
  }

  // Support long error messages by storing the error code and message in the
  // response body rather than in the RPC status.
  req->set_store_errors_in_response_body(true);
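  // (The RPC layer may truncate long status messages; the response body has
  // no such limit.)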

  // Build an index from fetch tensor name to first index in
  // output_tensor_names.
  std::unordered_map<string, int> output_name_to_offset;
  for (int i = 0, end = output_tensor_names.size(); i < end; ++i) {
    const string& name = output_tensor_names[i];
    if (output_name_to_offset.insert(std::make_pair(name, i)).second) {
      req->add_fetch(name);
    }
  }
  for (const string& target : target_node_names) {
    req->add_target(target);
  }

  CallOptions call_options;
  call_options.SetTimeout(req->options().timeout_in_ms());
  TF_RETURN_IF_ERROR(RunProto(&call_options, req.get(), resp.get()));

  // Look for an extended error returned in the response body.
  if (resp->status_code() != error::Code::OK) {
    return resp->status();
  }

  if (!output_tensor_names.empty()) {
    outputs->resize(output_tensor_names.size());
  }

  // Convert response back to Tensors in the correct order.
  for (size_t i = 0; i < resp->num_tensors(); ++i) {
    auto fetch_it = output_name_to_offset.find(resp->tensor_name(i));
    if (fetch_it == output_name_to_offset.end()) {
      return errors::Internal("Received response for unrequested fetch: ",
                              resp->tensor_name(i));
    }

    Tensor output;
    TF_RETURN_IF_ERROR(resp->TensorValue(i, &output));
    (*outputs)[fetch_it->second] = output;
  }
  // In the unlikely event that output_tensor_names contains duplicates, fill
  // in the duplicate values.
  if (output_name_to_offset.size() != output_tensor_names.size()) {
    for (int i = 0, end = output_tensor_names.size(); i < end; ++i) {
      const string& name = output_tensor_names[i];
      int offset = output_name_to_offset[name];
      if (offset != i) {
        (*outputs)[i] = (*outputs)[offset];
      }
    }
  }

  if (run_metadata) {
    run_metadata->Swap(resp->mutable_metadata());
  }

  return OkStatus();
}

Status GrpcSession::Run(const RunOptions& run_options,
                        const std::vector<std::pair<string, Tensor>>& inputs,
                        const std::vector<string>& output_tensor_names,
                        const std::vector<string>& target_node_names,
                        std::vector<Tensor>* outputs,
                        RunMetadata* run_metadata) {
  return RunHelper(run_options, inputs, output_tensor_names, target_node_names,
                   outputs, run_metadata, /* prun_handle */ "");
}

Status GrpcSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
                        const std::vector<string>& output_tensor_names,
                        const std::vector<string>& target_node_names,
                        std::vector<Tensor>* outputs) {
  RunOptions run_options;
  run_options.set_timeout_in_ms(options_.config.operation_timeout_in_ms());
  return Run(run_options, inputs, output_tensor_names, target_node_names,
             outputs, nullptr);
}

Status GrpcSession::RunProto(CallOptions* call_options,
                             MutableRunStepRequestWrapper* req,
                             MutableRunStepResponseWrapper* resp) {
  string handle;
  TF_RETURN_IF_ERROR(Handle(&handle));
  req->set_session_handle(handle);
  return master_->RunStep(call_options, req, resp);
}

Status GrpcSession::PRunSetup(const std::vector<string>& input_names,
                              const std::vector<string>& output_names,
                              const std::vector<string>& target_nodes,
                              string* handle) {
  // Convert to proto
  PartialRunSetupRequest req;
  PartialRunSetupResponse resp;
  CallOptions call_options;
  TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
  for (const string& feed : input_names) {
    req.add_feed(feed);
  }
  for (const string& fetch : output_names) {
    req.add_fetch(fetch);
  }
  for (const string& target : target_nodes) {
    req.add_target(target);
  }
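  // Tag remote requests with a unique id so the master can recognize retried
  // RPCs; the in-process master never sees retries, so no id is needed there.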
  if (!is_local_) req.set_request_id(GetUniqueRequestId());
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  TF_RETURN_IF_ERROR(master_->PartialRunSetup(&call_options, &req, &resp));
  *handle = resp.partial_run_handle();
  return OkStatus();
}

Status GrpcSession::PRun(const string& handle,
                         const std::vector<std::pair<string, Tensor>>& inputs,
                         const std::vector<string>& output_names,
                         std::vector<Tensor>* outputs) {
  RunOptions run_options;
  run_options.set_timeout_in_ms(options_.config.operation_timeout_in_ms());
  return RunHelper(run_options, inputs, output_names, /* targets */ {}, outputs,
                   /* run_metadata */ nullptr, handle);
}

Status GrpcSession::Close() {
  CloseSessionRequest req;
  {
    mutex_lock l(mu_);
    if (handle_.empty()) {
      return OkStatus();
    }
    req.set_session_handle(handle_);
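    // Clear the handle eagerly so that a second Close() becomes a no-op even
    // if the CloseSession RPC below fails.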
    handle_.clear();
  }
  CloseSessionResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  return master_->CloseSession(&call_options, &req, &resp);
}

Status GrpcSession::ListDevices(std::vector<DeviceAttributes>* response) {
  ListDevicesRequest req;
  {
    mutex_lock l(mu_);
    req.set_session_handle(handle_);
  }
  if (req.session_handle().empty()) {
    LOG(WARNING) << "GrpcSession::ListDevices will initialize the session with "
                    "an empty graph and other defaults because the session has "
                    "not yet been created.";
    GraphDef graph_def;
    TF_RETURN_IF_ERROR(Create(graph_def));
    {
      mutex_lock l(mu_);
      req.set_session_handle(handle_);
    }
  }
  ListDevicesResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  Status s = master_->ListDevices(&call_options, &req, &resp);
  if (!s.ok()) {
    LOG(ERROR) << "Could not list devices: " << s;
    return s;
  }

  response->clear();
  response->reserve(resp.local_device_size() + resp.remote_device_size());
  for (const auto& device_attr : resp.local_device()) {
    response->emplace_back(device_attr);
  }
  for (const auto& device_attr : resp.remote_device()) {
    response->emplace_back(device_attr);
  }
  return OkStatus();
}

void GrpcSession::SetRemoteMaster(std::unique_ptr<MasterInterface> master) {
  master_ = std::move(master);
}

// Static method.
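// Reset() can be called without an existing session, so it dials the master
// directly over a one-off channel instead of going through master_.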
Status GrpcSession::Reset(const SessionOptions& options,
                          const std::vector<string>& containers) {
  SharedGrpcChannelPtr master_channel;
  TF_RETURN_IF_ERROR(
      NewHostPortGrpcChannel(options.target.substr(kSchemePrefixLength),
                             /*rpc_options=*/nullptr, &master_channel));
  auto master = NewGrpcMaster(master_channel);
  ResetRequest req;
  req.mutable_container()->Reserve(containers.size());
  for (const auto& c : containers) req.add_container(c);
  ResetResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options.config.operation_timeout_in_ms());
  Status ret = master->Reset(&call_options, &req, &resp);
  delete master;
  return ret;
}

Status GrpcSession::MakeCallable(const CallableOptions& callable_options,
                                 CallableHandle* out_handle) {
  MakeCallableRequest req;
  TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
  *req.mutable_options() = callable_options;
  if (!is_local_) req.set_request_id(GetUniqueRequestId());
  MakeCallableResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  TF_RETURN_IF_ERROR(master_->MakeCallable(&call_options, &req, &resp));
  *out_handle = resp.handle();
  return OkStatus();
}

Status GrpcSession::RunCallable(CallableHandle handle,
                                const std::vector<Tensor>& feed_tensors,
                                std::vector<Tensor>* fetch_tensors,
                                RunMetadata* run_metadata) {
  RunCallableRequest req;
  TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
  req.set_handle(handle);
  if (!is_local_) req.set_request_id(GetUniqueRequestId());
  for (const Tensor& feed : feed_tensors) {
    feed.AsProtoTensorContent(req.mutable_feed()->Add());
  }

  RunCallableResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  TF_RETURN_IF_ERROR(master_->RunCallable(&call_options, &req, &resp));
  for (const TensorProto& fetch : resp.fetch()) {
    Tensor fetch_tensor;
    if (!fetch_tensor.FromProto(cpu_allocator(), fetch)) {
      return errors::Internal(
          "Could not parse fetched tensor data in response from master.");
    }
    fetch_tensors->push_back(std::move(fetch_tensor));
  }
  return OkStatus();
}

Status GrpcSession::ReleaseCallable(CallableHandle handle) {
  ReleaseCallableRequest req;
  TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
  req.set_handle(handle);
  ReleaseCallableResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  return master_->ReleaseCallable(&call_options, &req, &resp);
}

class GrpcSessionFactory : public SessionFactory {
 public:
  bool AcceptsOptions(const SessionOptions& options) override {
    return absl::StartsWith(options.target, kSchemePrefix);
  }

  Status NewSession(const SessionOptions& options,
                    Session** out_session) override {
    std::unique_ptr<GrpcSession> session;
    TF_RETURN_IF_ERROR(GrpcSession::Create(options, &session));
    *out_session = session.release();
    return OkStatus();
  }

  // Invokes the session-specific static method to reset containers.
  Status Reset(const SessionOptions& options,
               const std::vector<string>& containers) override {
    return GrpcSession::Reset(options, containers);
  }
};

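// Registers the factory under the name "GRPC_SESSION" at static-initialization
// time, so that NewSession() can construct a GrpcSession for any target that
// AcceptsOptions() above matches (i.e. targets beginning with "grpc://").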
class GrpcSessionRegistrar {
 public:
  GrpcSessionRegistrar() {
    SessionFactory::Register("GRPC_SESSION", new GrpcSessionFactory());
  }
};
static GrpcSessionRegistrar registrar;

}  // namespace tensorflow