xref: /aosp_15_r20/external/tensorflow/tensorflow/core/distributed_runtime/message_wrappers.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_
18 
19 #include "tensorflow/core/framework/allocator.h"
20 #include "tensorflow/core/framework/cost_graph.pb.h"
21 #include "tensorflow/core/framework/graph.pb.h"
22 #include "tensorflow/core/framework/step_stats.pb.h"
23 #include "tensorflow/core/framework/tensor.h"
24 #include "tensorflow/core/framework/tensor.pb.h"
25 #include "tensorflow/core/framework/versions.pb.h"
26 #include "tensorflow/core/protobuf/config.pb.h"
27 #include "tensorflow/core/protobuf/master.pb.h"
28 #include "tensorflow/core/protobuf/worker.pb.h"
29 
30 namespace tensorflow {
31 
32 ////////////////////////////////////////////////////////////////////////////////
33 //
34 // Wrapper classes for the `MasterService.RunStep` request message.
35 //
36 // The `RunStepRequest` message can contain potentially large tensor
37 // data as part of its `feed` submessages. Here we provide specialized
38 // wrappers that avoid copying the tensor data wherever possible.
39 //
40 // See `RunStepRequest` in tensorflow/core/protobuf/master.proto for the
41 // protocol buffer definition.
42 //
43 ////////////////////////////////////////////////////////////////////////////////
44 
45 // Abstract interface for an immutable RunStepRequest message.
46 //
47 // This interface is typically used by server-side components in the
48 // TensorFlow master.
49 class RunStepRequestWrapper {
50  public:
~RunStepRequestWrapper()51   virtual ~RunStepRequestWrapper() {}
52 
53   // REQUIRED: session_handle must be returned by a CreateSession call
54   // to the same master service.
55   virtual const string& session_handle() const = 0;
56 
57   // Partial run handle (optional). If specified, this will be a partial run
58   // execution, run up to the specified fetches.
59   virtual const string& partial_run_handle() const = 0;
60 
61   // Tensors to be fed in the step. Each feed is a named tensor.
62   virtual size_t num_feeds() const = 0;
63   virtual const string& feed_name(size_t i) const = 0;
64 
65   // Stores the content of the feed value at index `i` in `tensor`.
66   virtual Status FeedValue(size_t i, Tensor* out_tensor) const = 0;
67   virtual Status FeedValue(size_t i, TensorProto* out_tensor) const = 0;
68 
69   // Fetches. A list of tensor names. The caller expects a tensor to
70   // be returned for each fetch[i] (see RunStepResponse.tensor). The
71   // order of specified fetches does not change the execution order.
72   virtual size_t num_fetches() const = 0;
73   virtual const string& fetch_name(size_t i) const = 0;
74 
75   // Target Nodes. A list of node names. The named nodes will be run
76   // to but their outputs will not be fetched.
77   virtual size_t num_targets() const = 0;
78   virtual const string& target_name(size_t i) const = 0;
79 
80   // Options for the run call.
81   virtual const RunOptions& options() const = 0;
82 
83   // If true then some errors, e.g., execution errors that have long
84   // error messages, may return an OK RunStepResponse with the actual
85   // error saved in the status_code/status_error_message fields of the
86   // response body. This is a workaround since the RPC subsystem may
87   // truncate long metadata messages.
88   virtual bool store_errors_in_response_body() const = 0;
89 
90   // Unique identifier for this request. Every RunGraphRequest must have a
91   // unique request_id, and retried RunGraphRequests must have the same
92   // request_id. If request_id is zero, retry detection is disabled.
93   virtual int64_t request_id() const = 0;
94 
95   // Returns a human-readable representation of this message for debugging.
96   virtual string DebugString() const = 0;
97 
98   // Returns the wrapped data as a protocol buffer message.
99   virtual const RunStepRequest& ToProto() const = 0;
100 };
101 
// Abstract interface for a mutable RunStepRequest message.
//
// See `RunStepRequestWrapper` above for a description of the fields.
// Each setter/adder below mutates the field exposed by the like-named
// accessor of `RunStepRequestWrapper`.
class MutableRunStepRequestWrapper : public RunStepRequestWrapper {
 public:
  // Scalar handle fields (see `session_handle()` / `partial_run_handle()`).
  virtual void set_session_handle(const string& handle) = 0;
  virtual void set_partial_run_handle(const string& handle) = 0;
  // Adds a named feed / fetch / target. NOTE(review): presumably these
  // append to the respective lists; confirm in the implementations.
  virtual void add_feed(const string& name, const Tensor& value) = 0;
  virtual void add_fetch(const string& name) = 0;
  virtual void add_target(const string& name) = 0;
  // Mutable access to the options returned by `options()`.
  virtual RunOptions* mutable_options() = 0;
  virtual void set_store_errors_in_response_body(bool store_errors) = 0;
};
115 
// Specialized (and mutable) wrapper for RunStep requests between a client and
// master in the same address space.
//
// Fields are held directly as C++ members (feeds as `Tensor`s) rather than in
// a protobuf; a proto is only materialized on demand by `ToProto()`.
class InMemoryRunStepRequest : public MutableRunStepRequestWrapper {
 public:
  // RunStepRequestWrapper methods.
  const string& session_handle() const override;
  const string& partial_run_handle() const override;
  size_t num_feeds() const override;
  const string& feed_name(size_t i) const override;
  Status FeedValue(size_t i, Tensor* out_tensor) const override;
  Status FeedValue(size_t i, TensorProto* out_tensor) const override;
  size_t num_fetches() const override;
  const string& fetch_name(size_t i) const override;
  size_t num_targets() const override;
  const string& target_name(size_t i) const override;
  const RunOptions& options() const override;
  string DebugString() const override;
  const RunStepRequest& ToProto() const override;
  bool store_errors_in_response_body() const override;
  // NOTE(review): there is no request_id member in this class; see the
  // definition for the value this returns.
  int64_t request_id() const override;

  // MutableRunStepRequestWrapper methods.
  void set_session_handle(const string& handle) override;
  void set_partial_run_handle(const string& handle) override;
  void add_feed(const string& name, const Tensor& value) override;
  void add_fetch(const string& name) override;
  void add_target(const string& name) override;
  RunOptions* mutable_options() override;
  void set_store_errors_in_response_body(bool store_errors) override;

 private:
  string session_handle_;
  string partial_run_handle_;
  // (feed name, feed tensor) pairs.
  gtl::InlinedVector<std::pair<string, Tensor>, 4> feeds_;
  // Fetch tensor names and target node names.
  gtl::InlinedVector<string, 4> fetches_;
  gtl::InlinedVector<string, 4> targets_;
  RunOptions options_;
  bool store_errors_in_response_body_ = false;

  // Holds a cached and owned representation of the proto
  // representation of this request, if needed, so that `ToProto()`
  // can return a const RunStepRequest&.
  // NOTE(mrry): Although calls to `ToProto()` on this class are
  // expected to be rare, retaining ownership of the returned message
  // makes it easier to return a reference from the proto-backed
  // representations.
  mutable std::unique_ptr<RunStepRequest> proto_version_;
};
164 
// Wrapper for mutable RunStep requests that uses a protobuf message.
//
// This wrapper class should be used for RunStep requests between a
// client and master in different address spaces. All fields live in the
// owned `RunStepRequest` member.
class MutableProtoRunStepRequest : public MutableRunStepRequestWrapper {
 public:
  // RunStepRequestWrapper methods.
  const string& session_handle() const override;
  const string& partial_run_handle() const override;
  size_t num_feeds() const override;
  const string& feed_name(size_t i) const override;
  Status FeedValue(size_t i, Tensor* out_tensor) const override;
  Status FeedValue(size_t i, TensorProto* out_tensor) const override;
  size_t num_fetches() const override;
  const string& fetch_name(size_t i) const override;
  size_t num_targets() const override;
  const string& target_name(size_t i) const override;
  const RunOptions& options() const override;
  string DebugString() const override;
  const RunStepRequest& ToProto() const override;
  bool store_errors_in_response_body() const override;
  int64_t request_id() const override;

  // MutableRunStepRequestWrapper methods.
  void set_session_handle(const string& handle) override;
  void set_partial_run_handle(const string& handle) override;
  void add_feed(const string& name, const Tensor& value) override;
  void add_fetch(const string& name) override;
  void add_target(const string& name) override;
  RunOptions* mutable_options() override;
  void set_store_errors_in_response_body(bool store_errors) override;

 private:
  // Owned protobuf holding the request contents.
  RunStepRequest request_;
  // Grants MasterInterface access to `request_`. NOTE(review): presumably so
  // an RPC-based master can use the proto directly; confirm at use sites.
  friend class MasterInterface;
};
201 
202 // Wrapper for immutable RunStep requests that use a non-owned
203 // protobuf message.
204 //
205 // This interface is typically used by server-side components in the
206 // TensorFlow master, where the incoming message is a (possibly const)
207 // `RunStepRequest*`.
208 class ProtoRunStepRequest : public RunStepRequestWrapper {
209  public:
210   ProtoRunStepRequest(const RunStepRequest* request);
211 
212   // RunStepRequestWrapper methods.
213   const string& session_handle() const override;
214   const string& partial_run_handle() const override;
215   size_t num_feeds() const override;
216   const string& feed_name(size_t i) const override;
217   Status FeedValue(size_t i, Tensor* out_tensor) const override;
218   Status FeedValue(size_t i, TensorProto* out_tensor) const override;
219   size_t num_fetches() const override;
220   const string& fetch_name(size_t i) const override;
221   size_t num_targets() const override;
222   const string& target_name(size_t i) const override;
223   const RunOptions& options() const override;
224   string DebugString() const override;
225   const RunStepRequest& ToProto() const override;
226   bool store_errors_in_response_body() const override;
227   int64_t request_id() const override;
228 
229  private:
230   const RunStepRequest* const request_;  // Not owned.
231 };
232 
233 ////////////////////////////////////////////////////////////////////////////////
234 //
235 // Wrapper classes for the `WorkerService.RunGraph` request message.
236 //
237 // The `RunGraphRequest` message can contain potentially large tensor
238 // data as part of its `send` submessages. Here we provide specialized
239 // wrappers that avoid copying the tensor data wherever possible.
240 //
241 // See `RunGraphRequest` in tensorflow/core/protobuf/worker.proto for the
242 // protocol buffer definition.
243 //
244 ////////////////////////////////////////////////////////////////////////////////
245 
246 // Abstract interface for an immutable RunGraphRequest message.
247 //
248 // This interface is typically used by server-side components in the
249 // TensorFlow worker.
250 class RunGraphRequestWrapper {
251  public:
~RunGraphRequestWrapper()252   virtual ~RunGraphRequestWrapper() {}
253 
254   // The session handle used to register the graph. If empty, a single global
255   // namespace is used.
256   virtual const string& session_handle() const = 0;
257 
258   // Set to true if `CreateWorkerSession` was called for `session_handle`.
259   virtual bool create_worker_session_called() const = 0;
260 
261   // REQUIRED: graph_handle must be returned by a RegisterGraph call
262   // to the same WorkerService.
263   virtual const string& graph_handle() const = 0;
264 
265   // A unique ID to distinguish different runs of the same graph.
266   //
267   // The master generates a global unique `step_id` to distinguish
268   // different runs of the graph computation. Subgraphs communicate
269   // (e.g., send/recv ops) with each other using `step_id` to
270   // distinguish tensors generated by different runs.
271   virtual int64_t step_id() const = 0;
272 
273   // Options for this step.
274   virtual const ExecutorOpts& exec_opts() const = 0;
275 
276   // Sends the tensors in "send" into the graph before the run.
277   virtual size_t num_sends() const = 0;
278   virtual const string& send_key(size_t i) const = 0;
279   virtual Status SendValue(size_t i, Tensor* out_tensor) const = 0;
280 
281   // Fetches the keys into `RunGraphResponse.recv` after the run.
282   virtual size_t num_recvs() const = 0;
283   virtual const string& recv_key(size_t i) const = 0;
284 
285   // True if the RunGraphRequest is a partial run request.
286   virtual bool is_partial() const = 0;
287 
288   // True if this is the last partial run request in a sequence of requests.
289   virtual bool is_last_partial_run() const = 0;
290 
291   // If true then some errors, e.g., execution errors that have long
292   // error messages, may return an OK RunStepResponse with the actual
293   // error saved in the status_code/status_error_message fields of the
294   // response body. This is a workaround since the RPC subsystem may
295   // truncate long metadata messages.
296   virtual bool store_errors_in_response_body() const = 0;
297 
298   virtual int64_t request_id() const = 0;
299 
300   // Returns the wrapped data as a protocol buffer message.
301   virtual const RunGraphRequest& ToProto() const = 0;
302 };
303 
// Abstract interface for a mutable RunGraphRequest message.
//
// See `RunGraphRequestWrapper` above for a description of the fields.
// Each setter/adder below mutates the field exposed by the like-named
// accessor of `RunGraphRequestWrapper`.
class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper {
 public:
  virtual void set_session_handle(const string& handle) = 0;
  virtual void set_create_worker_session_called(bool called) = 0;
  virtual void set_graph_handle(const string& handle) = 0;
  virtual void set_step_id(int64_t step_id) = 0;
  virtual ExecutorOpts* mutable_exec_opts() = 0;

  // Stores the i^{th} feed value in `run_step_request` in this
  // request with the given `send_key`.
  virtual Status AddSendFromRunStepRequest(
      const RunStepRequestWrapper& run_step_request, size_t i,
      const string& send_key) = 0;
  // Counterpart of `AddSendFromRunStepRequest()` whose source is a
  // `RunCallableRequest`. NOTE(review): presumably stores its i^{th} feed
  // under `send_key`; confirm in the implementations.
  virtual Status AddSendFromRunCallableRequest(
      const RunCallableRequest& run_callable_request, size_t i,
      const string& send_key) = 0;

  // Adds a key to the list returned by `recv_key()`.
  virtual void add_recv_key(const string& recv_key) = 0;
  virtual void set_is_partial(bool is_partial) = 0;
  virtual void set_is_last_partial_run(bool is_last_partial_run) = 0;
  virtual void set_store_errors_in_response_body(bool store_errors) = 0;
  virtual void set_request_id(int64_t request_id) = 0;
};
330 
331 class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
332  public:
333   // RunGraphRequestWrapper methods.
334   const string& session_handle() const override;
335   const string& graph_handle() const override;
336   bool create_worker_session_called() const override;
337   int64_t step_id() const override;
338   const ExecutorOpts& exec_opts() const override;
339   size_t num_sends() const override;
340   const string& send_key(size_t i) const override;
341   Status SendValue(size_t i, Tensor* out_tensor) const override;
342   size_t num_recvs() const override;
343   const string& recv_key(size_t i) const override;
344   bool is_partial() const override;
345   bool is_last_partial_run() const override;
346   const RunGraphRequest& ToProto() const override;
347   bool store_errors_in_response_body() const override;
348   int64_t request_id() const override;
349 
350   // MutableRunGraphRequestWrapper methods.
351   void set_session_handle(const string& handle) override;
352   void set_create_worker_session_called(bool called) override;
353   void set_graph_handle(const string& handle) override;
354   void set_step_id(int64_t step_id) override;
355   ExecutorOpts* mutable_exec_opts() override;
356   Status AddSendFromRunStepRequest(
357       const RunStepRequestWrapper& run_step_request, size_t i,
358       const string& send_key) override;
359   Status AddSendFromRunCallableRequest(
360       const RunCallableRequest& run_callable_request, size_t i,
361       const string& send_key) override;
362   void add_recv_key(const string& recv_key) override;
363   void set_is_partial(bool is_partial) override;
364   void set_is_last_partial_run(bool is_last_partial_run) override;
365   void set_store_errors_in_response_body(bool store_errors) override;
366   void set_request_id(int64_t request_id) override;
367 
368  private:
369   string session_handle_;
370   bool create_worker_session_called_ = false;
371   string graph_handle_;
372   int64_t step_id_;
373   ExecutorOpts exec_opts_;
374   gtl::InlinedVector<std::pair<string, Tensor>, 4> sends_;
375   gtl::InlinedVector<string, 4> recvs_;
376   bool is_partial_ = false;
377   bool is_last_partial_run_ = false;
378   bool store_errors_in_response_body_ = false;
379   int64_t request_id_ = 0;
380 
381   // Holds a cached and owned representation of the proto
382   // representation of this request, if needed, so that `ToProto()`
383   // can return a const RunGraphRequest&.
384   // NOTE(mrry): Although calls to `ToProto()` on this class are
385   // expected to be rare, retaining ownership of the returned message
386   // makes it easier to return a reference from the proto-backed
387   // representations.
388   mutable std::unique_ptr<RunGraphRequest> proto_version_;
389 };
390 
// Mutable RunGraphRequest wrapper whose contents live in an owned
// `RunGraphRequest` protobuf message.
class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper {
 public:
  // RunGraphRequestWrapper methods.
  const string& session_handle() const override;
  bool create_worker_session_called() const override;
  const string& graph_handle() const override;
  int64_t step_id() const override;
  const ExecutorOpts& exec_opts() const override;
  size_t num_sends() const override;
  const string& send_key(size_t i) const override;
  Status SendValue(size_t i, Tensor* out_tensor) const override;
  size_t num_recvs() const override;
  const string& recv_key(size_t i) const override;
  bool is_partial() const override;
  bool is_last_partial_run() const override;
  bool store_errors_in_response_body() const override;
  int64_t request_id() const override;
  const RunGraphRequest& ToProto() const override;

  // MutableRunGraphRequestWrapper methods.
  void set_session_handle(const string& handle) override;
  void set_create_worker_session_called(bool called) override;
  void set_graph_handle(const string& handle) override;
  void set_step_id(int64_t step_id) override;
  ExecutorOpts* mutable_exec_opts() override;
  Status AddSendFromRunStepRequest(
      const RunStepRequestWrapper& run_step_request, size_t i,
      const string& send_key) override;
  Status AddSendFromRunCallableRequest(
      const RunCallableRequest& run_callable_request, size_t i,
      const string& send_key) override;
  void add_recv_key(const string& recv_key) override;
  void set_is_partial(bool is_partial) override;
  void set_is_last_partial_run(bool is_last_partial_run) override;
  void set_store_errors_in_response_body(bool store_errors) override;
  void set_request_id(int64_t request_id) override;

 private:
  // Owned protobuf holding the request contents.
  RunGraphRequest request_;
};
431 
432 class ProtoRunGraphRequest : public RunGraphRequestWrapper {
433  public:
434   ProtoRunGraphRequest(const RunGraphRequest* request);
435 
436   // RunGraphRequestWrapper methods.
437   const string& session_handle() const override;
438   bool create_worker_session_called() const override;
439   const string& graph_handle() const override;
440   int64_t step_id() const override;
441   const ExecutorOpts& exec_opts() const override;
442   size_t num_sends() const override;
443   const string& send_key(size_t i) const override;
444   Status SendValue(size_t i, Tensor* out_tensor) const override;
445   size_t num_recvs() const override;
446   const string& recv_key(size_t i) const override;
447   bool is_partial() const override;
448   bool is_last_partial_run() const override;
449   bool store_errors_in_response_body() const override;
450   int64_t request_id() const override;
451   const RunGraphRequest& ToProto() const override;
452 
453  private:
454   const RunGraphRequest* const request_;  // Not owned.
455 };
456 
457 ////////////////////////////////////////////////////////////////////////////////
458 //
459 // Wrapper classes for the `WorkerService.RunGraph` response message.
460 //
461 // The `RunGraphResponse` message can contain potentially large tensor
462 // data as part of its `recv` submessages. Here we provide specialized
463 // wrappers that avoid copying the tensor data wherever possible.
464 //
465 // See `RunGraphResponse` in tensorflow/core/protobuf/worker.proto for the
466 // protocol buffer definition.
467 //
468 ////////////////////////////////////////////////////////////////////////////////
469 
// Abstract interface for a mutable RunGraphResponse message.
//
// Note that there is no corresponding (immutable)
// RunGraphResponseWrapper class, because the RunGraphResponse object
// is always used as a mutable pointer.
class MutableRunGraphResponseWrapper {
 public:
  virtual ~MutableRunGraphResponseWrapper() {}

  // A list of tensors corresponding to those requested by
  // `RunGraphRequest.recv_key`.
  virtual size_t num_recvs() const = 0;
  virtual const string& recv_key(size_t i) const = 0;
  // NOTE: The following methods may perform a destructive read, for
  // efficiency.
  virtual Status RecvValue(size_t i, TensorProto* out_tensor) = 0;
  virtual Status RecvValue(size_t i, Tensor* out_tensor) = 0;
  // Adds a received tensor `value` under `key`.
  virtual void AddRecv(const string& key, const Tensor& value) = 0;

  // Submessages that store performance statistics about the subgraph
  // execution, if necessary.
  virtual StepStats* mutable_step_stats() = 0;
  virtual CostGraphDef* mutable_cost_graph() = 0;
  virtual size_t num_partition_graphs() const = 0;
  virtual GraphDef* mutable_partition_graph(size_t i) = 0;
  virtual void AddPartitionGraph(const GraphDef& partition_graph) = 0;

  // Returned status if requested. `status()` returns it as a `Status`;
  // `status_code()` and `status_error_message()` expose its parts.
  virtual Status status() const = 0;
  virtual errors::Code status_code() const = 0;
  virtual const string& status_error_message() const = 0;
  virtual void set_status(const Status& status) = 0;

 protected:
  // Returns a mutable protobuf message that represents the contents of
  // this wrapper, for passing to an RPC subsystem that will populate
  // the message.
  //
  // NOTE: Only `WorkerInterface` subclasses may call this method. The
  // `InMemoryRunGraphResponse` subclass does not implement this
  // method, and attempts to call it will fail with a fatal
  // error. However, as long as callers always call
  // `WorkerInterface::RunGraphAsync()` with a wrapper object returned
  // from `WorkerInterface::CreateRunGraphResponse()` called on the
  // *same* WorkerInterface object, this error will never trigger.
  virtual RunGraphResponse* get_proto() = 0;
  friend class WorkerInterface;
};
518 
519 class InMemoryRunGraphResponse : public MutableRunGraphResponseWrapper {
520  public:
521   // MutableRunGraphResponseWrapper methods.
522   size_t num_recvs() const override;
523   const string& recv_key(size_t i) const override;
524   Status RecvValue(size_t i, TensorProto* out_tensor) override;
525   Status RecvValue(size_t i, Tensor* out_tensor) override;
526   void AddRecv(const string& key, const Tensor& value) override;
527   StepStats* mutable_step_stats() override;
528   CostGraphDef* mutable_cost_graph() override;
529   size_t num_partition_graphs() const override;
530   GraphDef* mutable_partition_graph(size_t i) override;
531   void AddPartitionGraph(const GraphDef& partition_graph) override;
532   Status status() const override;
533   errors::Code status_code() const override;
534   const string& status_error_message() const override;
535   void set_status(const Status& status) override;
536 
537  protected:
538   // NOTE: This method is not implemented. See
539   // MutableRunGraphResponseWrapper for an explanation.
540   RunGraphResponse* get_proto() override;
541 
542  private:
543   gtl::InlinedVector<std::pair<string, Tensor>, 4> recvs_;
544   StepStats step_stats_;
545   CostGraphDef cost_graph_;
546   std::vector<GraphDef> partition_graphs_;
547   // Store the code and message separately so that they can be updated
548   // independently by setters.
549   Status status_;
550 };
551 
// Proto-based message wrapper for use on the client side of the RunGraph RPC.
// The wrapped `RunGraphResponse` is owned by this object.
class OwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper {
 public:
  // MutableRunGraphResponseWrapper methods.
  size_t num_recvs() const override;
  const string& recv_key(size_t i) const override;
  Status RecvValue(size_t i, TensorProto* out_tensor) override;
  Status RecvValue(size_t i, Tensor* out_tensor) override;
  void AddRecv(const string& key, const Tensor& value) override;
  StepStats* mutable_step_stats() override;
  CostGraphDef* mutable_cost_graph() override;
  size_t num_partition_graphs() const override;
  GraphDef* mutable_partition_graph(size_t i) override;
  void AddPartitionGraph(const GraphDef& partition_graph) override;
  Status status() const override;
  errors::Code status_code() const override;
  const string& status_error_message() const override;
  void set_status(const Status& status) override;

 protected:
  // NOTE(review): presumably returns a pointer to `response_`.
  RunGraphResponse* get_proto() override;

 private:
  RunGraphResponse response_;
};
577 
578 // Proto-based message wrapper for use on the server side of the RunGraph RPC.
579 class NonOwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper {
580  public:
581   NonOwnedProtoRunGraphResponse(RunGraphResponse* response);
582 
583   // MutableRunGraphResponseWrapper methods.
584   size_t num_recvs() const override;
585   const string& recv_key(size_t i) const override;
586   Status RecvValue(size_t i, TensorProto* out_tensor) override;
587   Status RecvValue(size_t i, Tensor* out_tensor) override;
588   void AddRecv(const string& key, const Tensor& value) override;
589   StepStats* mutable_step_stats() override;
590   CostGraphDef* mutable_cost_graph() override;
591   size_t num_partition_graphs() const override;
592   GraphDef* mutable_partition_graph(size_t i) override;
593   void AddPartitionGraph(const GraphDef& partition_graph) override;
594   Status status() const override;
595   errors::Code status_code() const override;
596   const string& status_error_message() const override;
597   void set_status(const Status& status) override;
598 
599  protected:
600   RunGraphResponse* get_proto() override;
601 
602  private:
603   RunGraphResponse* const response_;
604 };
605 
606 ////////////////////////////////////////////////////////////////////////////////
607 //
608 // Wrapper classes for the `MasterService.RunStep` response message.
609 //
610 // The `RunStepResponse` message can contain potentially large tensor
611 // data as part of its `tensor` submessages. Here we provide specialized
612 // wrappers that avoid copying the tensor data wherever possible.
613 //
614 // See `RunStepResponse` in tensorflow/core/protobuf/master.proto for the
615 // protocol buffer definition.
616 //
617 ////////////////////////////////////////////////////////////////////////////////
618 
// Abstract interface for a mutable RunStepResponse message.
//
// Note that there is no corresponding (immutable)
// RunStepResponseWrapper class, because the RunStepResponse object is
// always used as a mutable pointer.
class MutableRunStepResponseWrapper {
 public:
  // Declared here, defined out of line (unlike the other wrapper bases in
  // this header, whose destructors are inline).
  virtual ~MutableRunStepResponseWrapper();

  // The values of the tensors whose fetching was requested in the
  // RunStep call.
  //
  // NOTE: The order of the returned tensors may or may not match
  // the fetch order specified in RunStepRequest.
  virtual size_t num_tensors() const = 0;
  virtual const string& tensor_name(size_t i) const = 0;
  virtual Status TensorValue(size_t i, Tensor* out_tensor) const = 0;

  // Stores the i^{th} recv value in `run_graph_response` in this
  // response with the given `name`.
  virtual Status AddTensorFromRunGraphResponse(
      const string& name, MutableRunGraphResponseWrapper* run_graph_response,
      size_t i) = 0;

  // Returned metadata if requested in the options.
  virtual const RunMetadata& metadata() const = 0;
  virtual RunMetadata* mutable_metadata() = 0;

  // Returned status if requested. `status()` returns it as a `Status`;
  // `status_code()` and `status_error_message()` expose its parts.
  virtual Status status() const = 0;
  virtual errors::Code status_code() const = 0;
  virtual const string& status_error_message() const = 0;
  virtual void set_status(const Status& status) = 0;

 protected:
  // Returns a mutable protobuf message that represents the contents of
  // this wrapper, for passing to an RPC subsystem that will populate
  // the message.
  //
  // NOTE: Only `MasterInterface` subclasses may call this method. The
  // `InMemoryRunStepResponse` subclass does not implement this
  // method, and attempts to call it will fail with a fatal
  // error. However, as long as callers always call
  // `MasterInterface::RunStep()` with a wrapper object returned
  // from `MasterInterface::CreateRunStepResponse()` called on the
  // *same* MasterInterface object, this error will never trigger.
  virtual RunStepResponse* get_proto() = 0;
  friend class MasterInterface;
};
668 
669 class InMemoryRunStepResponse : public MutableRunStepResponseWrapper {
670  public:
671   // MutableRunStepResponseWrapper methods.
672   size_t num_tensors() const override;
673   const string& tensor_name(size_t i) const override;
674   Status TensorValue(size_t i, Tensor* out_tensor) const override;
675   Status AddTensorFromRunGraphResponse(
676       const string& name, MutableRunGraphResponseWrapper* run_graph_response,
677       size_t i) override;
678   const RunMetadata& metadata() const override;
679   RunMetadata* mutable_metadata() override;
680   Status status() const override;
681   errors::Code status_code() const override;
682   const string& status_error_message() const override;
683   void set_status(const Status& status) override;
684 
685  protected:
686   // NOTE: This method is not implemented. See
687   // MutableRunGraphResponseWrapper for an explanation.
688   RunStepResponse* get_proto() override;
689 
690  private:
691   gtl::InlinedVector<std::pair<string, Tensor>, 4> tensors_;
692   RunMetadata metadata_;
693   // Store the code and message separately so that they can be updated
694   // independently by setters.
695   Status status_;
696 };
697 
// Proto-based message wrapper for use on the client side of the RunStep RPC.
// The wrapped `RunStepResponse` is owned by this object.
class OwnedProtoRunStepResponse : public MutableRunStepResponseWrapper {
 public:
  // MutableRunStepResponseWrapper methods.
  size_t num_tensors() const override;
  const string& tensor_name(size_t i) const override;
  Status TensorValue(size_t i, Tensor* out_tensor) const override;
  Status AddTensorFromRunGraphResponse(
      const string& name, MutableRunGraphResponseWrapper* run_graph_response,
      size_t i) override;
  const RunMetadata& metadata() const override;
  RunMetadata* mutable_metadata() override;
  Status status() const override;
  errors::Code status_code() const override;
  const string& status_error_message() const override;
  void set_status(const Status& status) override;

 protected:
  // NOTE(review): presumably returns a pointer to `response_`.
  RunStepResponse* get_proto() override;

 private:
  RunStepResponse response_;
};
721 
722 // Proto-based message wrapper for use on the server side of the RunStep RPC.
723 class NonOwnedProtoRunStepResponse : public MutableRunStepResponseWrapper {
724  public:
725   NonOwnedProtoRunStepResponse(RunStepResponse* response);
726 
727   // MutableRunStepResponseWrapper methods.
728   size_t num_tensors() const override;
729   const string& tensor_name(size_t i) const override;
730   Status TensorValue(size_t i, Tensor* out_tensor) const override;
731   Status AddTensorFromRunGraphResponse(
732       const string& name, MutableRunGraphResponseWrapper* run_graph_response,
733       size_t i) override;
734   const RunMetadata& metadata() const override;
735   RunMetadata* mutable_metadata() override;
736   Status status() const override;
737   errors::Code status_code() const override;
738   const string& status_error_message() const override;
739   void set_status(const Status& status) override;
740 
741  protected:
742   RunStepResponse* get_proto() override;
743 
744  private:
745   RunStepResponse* response_;  // Not owned.
746 };
747 
// Parses `tensor_proto` into `*out_tensor`.
// NOTE(review): the bool result presumably reports whether parsing
// succeeded; confirm against the definition before relying on it.
bool ParseTensorProtoToTensor(const TensorProto& tensor_proto,
                              Tensor* out_tensor);
750 
751 }  // namespace tensorflow
752 
753 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_
754