1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_ 18 19 #include "tensorflow/core/framework/allocator.h" 20 #include "tensorflow/core/framework/cost_graph.pb.h" 21 #include "tensorflow/core/framework/graph.pb.h" 22 #include "tensorflow/core/framework/step_stats.pb.h" 23 #include "tensorflow/core/framework/tensor.h" 24 #include "tensorflow/core/framework/tensor.pb.h" 25 #include "tensorflow/core/framework/versions.pb.h" 26 #include "tensorflow/core/protobuf/config.pb.h" 27 #include "tensorflow/core/protobuf/master.pb.h" 28 #include "tensorflow/core/protobuf/worker.pb.h" 29 30 namespace tensorflow { 31 32 //////////////////////////////////////////////////////////////////////////////// 33 // 34 // Wrapper classes for the `MasterService.RunStep` request message. 35 // 36 // The `RunStepRequest` message can contain potentially large tensor 37 // data as part of its `feed` submessages. Here we provide specialized 38 // wrappers that avoid copying the tensor data wherever possible. 39 // 40 // See `RunStepRequest` in tensorflow/core/protobuf/master.proto for the 41 // protocol buffer definition. 42 // 43 //////////////////////////////////////////////////////////////////////////////// 44 45 // Abstract interface for an immutable RunStepRequest message. 46 // 47 // This interface is typically used by server-side components in the 48 // TensorFlow master. 49 class RunStepRequestWrapper { 50 public: ~RunStepRequestWrapper()51 virtual ~RunStepRequestWrapper() {} 52 53 // REQUIRED: session_handle must be returned by a CreateSession call 54 // to the same master service. 55 virtual const string& session_handle() const = 0; 56 57 // Partial run handle (optional). If specified, this will be a partial run 58 // execution, run up to the specified fetches. 59 virtual const string& partial_run_handle() const = 0; 60 61 // Tensors to be fed in the step. Each feed is a named tensor. 62 virtual size_t num_feeds() const = 0; 63 virtual const string& feed_name(size_t i) const = 0; 64 65 // Stores the content of the feed value at index `i` in `tensor`. 66 virtual Status FeedValue(size_t i, Tensor* out_tensor) const = 0; 67 virtual Status FeedValue(size_t i, TensorProto* out_tensor) const = 0; 68 69 // Fetches. A list of tensor names. The caller expects a tensor to 70 // be returned for each fetch[i] (see RunStepResponse.tensor). The 71 // order of specified fetches does not change the execution order. 72 virtual size_t num_fetches() const = 0; 73 virtual const string& fetch_name(size_t i) const = 0; 74 75 // Target Nodes. A list of node names. The named nodes will be run 76 // to but their outputs will not be fetched. 77 virtual size_t num_targets() const = 0; 78 virtual const string& target_name(size_t i) const = 0; 79 80 // Options for the run call. 81 virtual const RunOptions& options() const = 0; 82 83 // If true then some errors, e.g., execution errors that have long 84 // error messages, may return an OK RunStepResponse with the actual 85 // error saved in the status_code/status_error_message fields of the 86 // response body. This is a workaround since the RPC subsystem may 87 // truncate long metadata messages. 88 virtual bool store_errors_in_response_body() const = 0; 89 90 // Unique identifier for this request. Every RunGraphRequest must have a 91 // unique request_id, and retried RunGraphRequests must have the same 92 // request_id. If request_id is zero, retry detection is disabled. 93 virtual int64_t request_id() const = 0; 94 95 // Returns a human-readable representation of this message for debugging. 96 virtual string DebugString() const = 0; 97 98 // Returns the wrapped data as a protocol buffer message. 99 virtual const RunStepRequest& ToProto() const = 0; 100 }; 101 102 // Abstract interface for a mutable RunStepRequest message. 103 // 104 // See `RunStepRequestWrapper` above for a description of the fields. 105 class MutableRunStepRequestWrapper : public RunStepRequestWrapper { 106 public: 107 virtual void set_session_handle(const string& handle) = 0; 108 virtual void set_partial_run_handle(const string& handle) = 0; 109 virtual void add_feed(const string& name, const Tensor& value) = 0; 110 virtual void add_fetch(const string& name) = 0; 111 virtual void add_target(const string& name) = 0; 112 virtual RunOptions* mutable_options() = 0; 113 virtual void set_store_errors_in_response_body(bool store_errors) = 0; 114 }; 115 116 // Specialized (and mutable) wrapper for RunStep requests between a client and 117 // master in the same address space. 118 class InMemoryRunStepRequest : public MutableRunStepRequestWrapper { 119 public: 120 // RunStepRequestWrapper methods. 121 const string& session_handle() const override; 122 const string& partial_run_handle() const override; 123 size_t num_feeds() const override; 124 const string& feed_name(size_t i) const override; 125 Status FeedValue(size_t i, Tensor* out_tensor) const override; 126 Status FeedValue(size_t i, TensorProto* out_tensor) const override; 127 size_t num_fetches() const override; 128 const string& fetch_name(size_t i) const override; 129 size_t num_targets() const override; 130 const string& target_name(size_t i) const override; 131 const RunOptions& options() const override; 132 string DebugString() const override; 133 const RunStepRequest& ToProto() const override; 134 bool store_errors_in_response_body() const override; 135 int64_t request_id() const override; 136 137 // MutableRunStepRequestWrapper methods. 138 void set_session_handle(const string& handle) override; 139 void set_partial_run_handle(const string& handle) override; 140 void add_feed(const string& name, const Tensor& value) override; 141 void add_fetch(const string& name) override; 142 void add_target(const string& name) override; 143 RunOptions* mutable_options() override; 144 void set_store_errors_in_response_body(bool store_errors) override; 145 146 private: 147 string session_handle_; 148 string partial_run_handle_; 149 gtl::InlinedVector<std::pair<string, Tensor>, 4> feeds_; 150 gtl::InlinedVector<string, 4> fetches_; 151 gtl::InlinedVector<string, 4> targets_; 152 RunOptions options_; 153 bool store_errors_in_response_body_ = false; 154 155 // Holds a cached and owned representation of the proto 156 // representation of this request, if needed, so that `ToProto()` 157 // can return a const RunStepRequest&. 158 // NOTE(mrry): Although calls to `ToProto()` on this class are 159 // expected to be rare, retaining ownership of the returned message 160 // makes it easier to return a reference from the proto-backed 161 // representations. 162 mutable std::unique_ptr<RunStepRequest> proto_version_; 163 }; 164 165 // Wrapper for mutable RunStep requests that uses a protobuf message. 166 // 167 // This wrapper class should be used for RunStep requests between a 168 // client and master in different address spaces. 169 class MutableProtoRunStepRequest : public MutableRunStepRequestWrapper { 170 public: 171 // RunStepRequestWrapper methods. 172 const string& session_handle() const override; 173 const string& partial_run_handle() const override; 174 size_t num_feeds() const override; 175 const string& feed_name(size_t i) const override; 176 Status FeedValue(size_t i, Tensor* out_tensor) const override; 177 Status FeedValue(size_t i, TensorProto* out_tensor) const override; 178 size_t num_fetches() const override; 179 const string& fetch_name(size_t i) const override; 180 size_t num_targets() const override; 181 const string& target_name(size_t i) const override; 182 const RunOptions& options() const override; 183 string DebugString() const override; 184 const RunStepRequest& ToProto() const override; 185 bool store_errors_in_response_body() const override; 186 int64_t request_id() const override; 187 188 // MutableRunStepRequestWrapper methods. 189 void set_session_handle(const string& handle) override; 190 void set_partial_run_handle(const string& handle) override; 191 void add_feed(const string& name, const Tensor& value) override; 192 void add_fetch(const string& name) override; 193 void add_target(const string& name) override; 194 RunOptions* mutable_options() override; 195 void set_store_errors_in_response_body(bool store_errors) override; 196 197 private: 198 RunStepRequest request_; 199 friend class MasterInterface; 200 }; 201 202 // Wrapper for immutable RunStep requests that use a non-owned 203 // protobuf message. 204 // 205 // This interface is typically used by server-side components in the 206 // TensorFlow master, where the incoming message is a (possibly const) 207 // `RunStepRequest*`. 208 class ProtoRunStepRequest : public RunStepRequestWrapper { 209 public: 210 ProtoRunStepRequest(const RunStepRequest* request); 211 212 // RunStepRequestWrapper methods. 213 const string& session_handle() const override; 214 const string& partial_run_handle() const override; 215 size_t num_feeds() const override; 216 const string& feed_name(size_t i) const override; 217 Status FeedValue(size_t i, Tensor* out_tensor) const override; 218 Status FeedValue(size_t i, TensorProto* out_tensor) const override; 219 size_t num_fetches() const override; 220 const string& fetch_name(size_t i) const override; 221 size_t num_targets() const override; 222 const string& target_name(size_t i) const override; 223 const RunOptions& options() const override; 224 string DebugString() const override; 225 const RunStepRequest& ToProto() const override; 226 bool store_errors_in_response_body() const override; 227 int64_t request_id() const override; 228 229 private: 230 const RunStepRequest* const request_; // Not owned. 231 }; 232 233 //////////////////////////////////////////////////////////////////////////////// 234 // 235 // Wrapper classes for the `WorkerService.RunGraph` request message. 236 // 237 // The `RunGraphRequest` message can contain potentially large tensor 238 // data as part of its `send` submessages. Here we provide specialized 239 // wrappers that avoid copying the tensor data wherever possible. 240 // 241 // See `RunGraphRequest` in tensorflow/core/protobuf/worker.proto for the 242 // protocol buffer definition. 243 // 244 //////////////////////////////////////////////////////////////////////////////// 245 246 // Abstract interface for an immutable RunGraphRequest message. 247 // 248 // This interface is typically used by server-side components in the 249 // TensorFlow worker. 250 class RunGraphRequestWrapper { 251 public: ~RunGraphRequestWrapper()252 virtual ~RunGraphRequestWrapper() {} 253 254 // The session handle used to register the graph. If empty, a single global 255 // namespace is used. 256 virtual const string& session_handle() const = 0; 257 258 // Set to true if `CreateWorkerSession` was called for `session_handle`. 259 virtual bool create_worker_session_called() const = 0; 260 261 // REQUIRED: graph_handle must be returned by a RegisterGraph call 262 // to the same WorkerService. 263 virtual const string& graph_handle() const = 0; 264 265 // A unique ID to distinguish different runs of the same graph. 266 // 267 // The master generates a global unique `step_id` to distinguish 268 // different runs of the graph computation. Subgraphs communicate 269 // (e.g., send/recv ops) with each other using `step_id` to 270 // distinguish tensors generated by different runs. 271 virtual int64_t step_id() const = 0; 272 273 // Options for this step. 274 virtual const ExecutorOpts& exec_opts() const = 0; 275 276 // Sends the tensors in "send" into the graph before the run. 277 virtual size_t num_sends() const = 0; 278 virtual const string& send_key(size_t i) const = 0; 279 virtual Status SendValue(size_t i, Tensor* out_tensor) const = 0; 280 281 // Fetches the keys into `RunGraphResponse.recv` after the run. 282 virtual size_t num_recvs() const = 0; 283 virtual const string& recv_key(size_t i) const = 0; 284 285 // True if the RunGraphRequest is a partial run request. 286 virtual bool is_partial() const = 0; 287 288 // True if this is the last partial run request in a sequence of requests. 289 virtual bool is_last_partial_run() const = 0; 290 291 // If true then some errors, e.g., execution errors that have long 292 // error messages, may return an OK RunStepResponse with the actual 293 // error saved in the status_code/status_error_message fields of the 294 // response body. This is a workaround since the RPC subsystem may 295 // truncate long metadata messages. 296 virtual bool store_errors_in_response_body() const = 0; 297 298 virtual int64_t request_id() const = 0; 299 300 // Returns the wrapped data as a protocol buffer message. 301 virtual const RunGraphRequest& ToProto() const = 0; 302 }; 303 304 // Abstract interface for a mutable RunGraphRequest message. 305 // 306 // See `RunGraphRequestWrapper` above for a description of the fields. 307 class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper { 308 public: 309 virtual void set_session_handle(const string& handle) = 0; 310 virtual void set_create_worker_session_called(bool called) = 0; 311 virtual void set_graph_handle(const string& handle) = 0; 312 virtual void set_step_id(int64_t step_id) = 0; 313 virtual ExecutorOpts* mutable_exec_opts() = 0; 314 315 // Stores the i^{th} feed value in `run_step_request` in this 316 // request with the given `send_key`. 317 virtual Status AddSendFromRunStepRequest( 318 const RunStepRequestWrapper& run_step_request, size_t i, 319 const string& send_key) = 0; 320 virtual Status AddSendFromRunCallableRequest( 321 const RunCallableRequest& run_callable_request, size_t i, 322 const string& send_key) = 0; 323 324 virtual void add_recv_key(const string& recv_key) = 0; 325 virtual void set_is_partial(bool is_partial) = 0; 326 virtual void set_is_last_partial_run(bool is_last_partial_run) = 0; 327 virtual void set_store_errors_in_response_body(bool store_errors) = 0; 328 virtual void set_request_id(int64_t request_id) = 0; 329 }; 330 331 class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper { 332 public: 333 // RunGraphRequestWrapper methods. 334 const string& session_handle() const override; 335 const string& graph_handle() const override; 336 bool create_worker_session_called() const override; 337 int64_t step_id() const override; 338 const ExecutorOpts& exec_opts() const override; 339 size_t num_sends() const override; 340 const string& send_key(size_t i) const override; 341 Status SendValue(size_t i, Tensor* out_tensor) const override; 342 size_t num_recvs() const override; 343 const string& recv_key(size_t i) const override; 344 bool is_partial() const override; 345 bool is_last_partial_run() const override; 346 const RunGraphRequest& ToProto() const override; 347 bool store_errors_in_response_body() const override; 348 int64_t request_id() const override; 349 350 // MutableRunGraphRequestWrapper methods. 351 void set_session_handle(const string& handle) override; 352 void set_create_worker_session_called(bool called) override; 353 void set_graph_handle(const string& handle) override; 354 void set_step_id(int64_t step_id) override; 355 ExecutorOpts* mutable_exec_opts() override; 356 Status AddSendFromRunStepRequest( 357 const RunStepRequestWrapper& run_step_request, size_t i, 358 const string& send_key) override; 359 Status AddSendFromRunCallableRequest( 360 const RunCallableRequest& run_callable_request, size_t i, 361 const string& send_key) override; 362 void add_recv_key(const string& recv_key) override; 363 void set_is_partial(bool is_partial) override; 364 void set_is_last_partial_run(bool is_last_partial_run) override; 365 void set_store_errors_in_response_body(bool store_errors) override; 366 void set_request_id(int64_t request_id) override; 367 368 private: 369 string session_handle_; 370 bool create_worker_session_called_ = false; 371 string graph_handle_; 372 int64_t step_id_; 373 ExecutorOpts exec_opts_; 374 gtl::InlinedVector<std::pair<string, Tensor>, 4> sends_; 375 gtl::InlinedVector<string, 4> recvs_; 376 bool is_partial_ = false; 377 bool is_last_partial_run_ = false; 378 bool store_errors_in_response_body_ = false; 379 int64_t request_id_ = 0; 380 381 // Holds a cached and owned representation of the proto 382 // representation of this request, if needed, so that `ToProto()` 383 // can return a const RunGraphRequest&. 384 // NOTE(mrry): Although calls to `ToProto()` on this class are 385 // expected to be rare, retaining ownership of the returned message 386 // makes it easier to return a reference from the proto-backed 387 // representations. 388 mutable std::unique_ptr<RunGraphRequest> proto_version_; 389 }; 390 391 class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper { 392 public: 393 // RunGraphRequestWrapper methods. 394 const string& session_handle() const override; 395 bool create_worker_session_called() const override; 396 const string& graph_handle() const override; 397 int64_t step_id() const override; 398 const ExecutorOpts& exec_opts() const override; 399 size_t num_sends() const override; 400 const string& send_key(size_t i) const override; 401 Status SendValue(size_t i, Tensor* out_tensor) const override; 402 size_t num_recvs() const override; 403 const string& recv_key(size_t i) const override; 404 bool is_partial() const override; 405 bool is_last_partial_run() const override; 406 bool store_errors_in_response_body() const override; 407 int64_t request_id() const override; 408 const RunGraphRequest& ToProto() const override; 409 410 // MutableRunGraphRequestWrapper methods. 411 void set_session_handle(const string& handle) override; 412 void set_create_worker_session_called(bool called) override; 413 void set_graph_handle(const string& handle) override; 414 void set_step_id(int64_t step_id) override; 415 ExecutorOpts* mutable_exec_opts() override; 416 Status AddSendFromRunStepRequest( 417 const RunStepRequestWrapper& run_step_request, size_t i, 418 const string& send_key) override; 419 Status AddSendFromRunCallableRequest( 420 const RunCallableRequest& run_callable_request, size_t i, 421 const string& send_key) override; 422 void add_recv_key(const string& recv_key) override; 423 void set_is_partial(bool is_partial) override; 424 void set_is_last_partial_run(bool is_last_partial_run) override; 425 void set_store_errors_in_response_body(bool store_errors) override; 426 void set_request_id(int64_t request_id) override; 427 428 private: 429 RunGraphRequest request_; 430 }; 431 432 class ProtoRunGraphRequest : public RunGraphRequestWrapper { 433 public: 434 ProtoRunGraphRequest(const RunGraphRequest* request); 435 436 // RunGraphRequestWrapper methods. 437 const string& session_handle() const override; 438 bool create_worker_session_called() const override; 439 const string& graph_handle() const override; 440 int64_t step_id() const override; 441 const ExecutorOpts& exec_opts() const override; 442 size_t num_sends() const override; 443 const string& send_key(size_t i) const override; 444 Status SendValue(size_t i, Tensor* out_tensor) const override; 445 size_t num_recvs() const override; 446 const string& recv_key(size_t i) const override; 447 bool is_partial() const override; 448 bool is_last_partial_run() const override; 449 bool store_errors_in_response_body() const override; 450 int64_t request_id() const override; 451 const RunGraphRequest& ToProto() const override; 452 453 private: 454 const RunGraphRequest* const request_; // Not owned. 455 }; 456 457 //////////////////////////////////////////////////////////////////////////////// 458 // 459 // Wrapper classes for the `WorkerService.RunGraph` response message. 460 // 461 // The `RunGraphResponse` message can contain potentially large tensor 462 // data as part of its `recv` submessages. Here we provide specialized 463 // wrappers that avoid copying the tensor data wherever possible. 464 // 465 // See `RunGraphResponse` in tensorflow/core/protobuf/worker.proto for the 466 // protocol buffer definition. 467 // 468 //////////////////////////////////////////////////////////////////////////////// 469 470 // Abstract interface for a mutable RunGraphResponse message. 471 // 472 // Note that there is no corresponding (immutable) 473 // RunGraphResponseWrapper class, because the RunGraphResponse object 474 // is always used as a mutable pointer. 475 class MutableRunGraphResponseWrapper { 476 public: ~MutableRunGraphResponseWrapper()477 virtual ~MutableRunGraphResponseWrapper() {} 478 479 // A list of tensors corresponding to those requested by 480 // `RunGraphRequest.recv_key`. 481 virtual size_t num_recvs() const = 0; 482 virtual const string& recv_key(size_t i) const = 0; 483 // NOTE: The following methods may perform a destructive read, for 484 // efficiency. 485 virtual Status RecvValue(size_t i, TensorProto* out_tensor) = 0; 486 virtual Status RecvValue(size_t i, Tensor* out_tensor) = 0; 487 virtual void AddRecv(const string& key, const Tensor& value) = 0; 488 489 // Submessages that store performance statistics about the subgraph 490 // execution, if necessary. 491 virtual StepStats* mutable_step_stats() = 0; 492 virtual CostGraphDef* mutable_cost_graph() = 0; 493 virtual size_t num_partition_graphs() const = 0; 494 virtual GraphDef* mutable_partition_graph(size_t i) = 0; 495 virtual void AddPartitionGraph(const GraphDef& partition_graph) = 0; 496 497 // Returned status if requested. 498 virtual Status status() const = 0; 499 virtual errors::Code status_code() const = 0; 500 virtual const string& status_error_message() const = 0; 501 virtual void set_status(const Status& status) = 0; 502 503 protected: 504 // Returns a mutable protobuf message that represents the contents of 505 // this wrapper, for passing to an RPC subsystem that will populate 506 // the message. 507 // 508 // NOTE: Only `WorkerInterface` subclasses may call this method. The 509 // `InMemoryRunGraphResponse` subclass does not implement this 510 // method, and attempts to call it will fail with a fatal 511 // error. However, as long as callers always call 512 // `WorkerInterface::RunGraphAsync()` with a wrapper object returned 513 // from `WorkerInterface::CreateRunGraphResponse()` called on the 514 // *same* WorkerInterface object, this error will never trigger. 515 virtual RunGraphResponse* get_proto() = 0; 516 friend class WorkerInterface; 517 }; 518 519 class InMemoryRunGraphResponse : public MutableRunGraphResponseWrapper { 520 public: 521 // MutableRunGraphResponseWrapper methods. 522 size_t num_recvs() const override; 523 const string& recv_key(size_t i) const override; 524 Status RecvValue(size_t i, TensorProto* out_tensor) override; 525 Status RecvValue(size_t i, Tensor* out_tensor) override; 526 void AddRecv(const string& key, const Tensor& value) override; 527 StepStats* mutable_step_stats() override; 528 CostGraphDef* mutable_cost_graph() override; 529 size_t num_partition_graphs() const override; 530 GraphDef* mutable_partition_graph(size_t i) override; 531 void AddPartitionGraph(const GraphDef& partition_graph) override; 532 Status status() const override; 533 errors::Code status_code() const override; 534 const string& status_error_message() const override; 535 void set_status(const Status& status) override; 536 537 protected: 538 // NOTE: This method is not implemented. See 539 // MutableRunGraphResponseWrapper for an explanation. 540 RunGraphResponse* get_proto() override; 541 542 private: 543 gtl::InlinedVector<std::pair<string, Tensor>, 4> recvs_; 544 StepStats step_stats_; 545 CostGraphDef cost_graph_; 546 std::vector<GraphDef> partition_graphs_; 547 // Store the code and message separately so that they can be updated 548 // independently by setters. 549 Status status_; 550 }; 551 552 // Proto-based message wrapper for use on the client side of the RunGraph RPC. 553 class OwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper { 554 public: 555 // MutableRunGraphResponseWrapper methods. 556 size_t num_recvs() const override; 557 const string& recv_key(size_t i) const override; 558 Status RecvValue(size_t i, TensorProto* out_tensor) override; 559 Status RecvValue(size_t i, Tensor* out_tensor) override; 560 void AddRecv(const string& key, const Tensor& value) override; 561 StepStats* mutable_step_stats() override; 562 CostGraphDef* mutable_cost_graph() override; 563 size_t num_partition_graphs() const override; 564 GraphDef* mutable_partition_graph(size_t i) override; 565 void AddPartitionGraph(const GraphDef& partition_graph) override; 566 Status status() const override; 567 errors::Code status_code() const override; 568 const string& status_error_message() const override; 569 void set_status(const Status& status) override; 570 571 protected: 572 RunGraphResponse* get_proto() override; 573 574 private: 575 RunGraphResponse response_; 576 }; 577 578 // Proto-based message wrapper for use on the server side of the RunGraph RPC. 579 class NonOwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper { 580 public: 581 NonOwnedProtoRunGraphResponse(RunGraphResponse* response); 582 583 // MutableRunGraphResponseWrapper methods. 584 size_t num_recvs() const override; 585 const string& recv_key(size_t i) const override; 586 Status RecvValue(size_t i, TensorProto* out_tensor) override; 587 Status RecvValue(size_t i, Tensor* out_tensor) override; 588 void AddRecv(const string& key, const Tensor& value) override; 589 StepStats* mutable_step_stats() override; 590 CostGraphDef* mutable_cost_graph() override; 591 size_t num_partition_graphs() const override; 592 GraphDef* mutable_partition_graph(size_t i) override; 593 void AddPartitionGraph(const GraphDef& partition_graph) override; 594 Status status() const override; 595 errors::Code status_code() const override; 596 const string& status_error_message() const override; 597 void set_status(const Status& status) override; 598 599 protected: 600 RunGraphResponse* get_proto() override; 601 602 private: 603 RunGraphResponse* const response_; 604 }; 605 606 //////////////////////////////////////////////////////////////////////////////// 607 // 608 // Wrapper classes for the `MasterService.RunStep` response message. 609 // 610 // The `RunStepResponse` message can contain potentially large tensor 611 // data as part of its `tensor` submessages. Here we provide specialized 612 // wrappers that avoid copying the tensor data wherever possible. 613 // 614 // See `RunStepResponse` in tensorflow/core/protobuf/master.proto for the 615 // protocol buffer definition. 616 // 617 //////////////////////////////////////////////////////////////////////////////// 618 619 // Abstract interface for a mutable RunStepResponse message. 620 // 621 // Note that there is no corresponding (immutable) 622 // RunStepResponseWrapper class, because the RunStepResponse object is 623 // always used as a mutable pointer. 624 class MutableRunStepResponseWrapper { 625 public: 626 virtual ~MutableRunStepResponseWrapper(); 627 628 // The values of the tensors whose fetching was requested in the 629 // RunStep call. 630 // 631 // NOTE: The order of the returned tensors may or may not match 632 // the fetch order specified in RunStepRequest. 633 virtual size_t num_tensors() const = 0; 634 virtual const string& tensor_name(size_t i) const = 0; 635 virtual Status TensorValue(size_t i, Tensor* out_tensor) const = 0; 636 637 // Stores the i^{th} recv value in `run_graph_response` in this 638 // response with the given `name`. 639 virtual Status AddTensorFromRunGraphResponse( 640 const string& name, MutableRunGraphResponseWrapper* run_graph_response, 641 size_t i) = 0; 642 643 // Returned metadata if requested in the options. 644 virtual const RunMetadata& metadata() const = 0; 645 virtual RunMetadata* mutable_metadata() = 0; 646 647 // Returned status if requested. 648 virtual Status status() const = 0; 649 virtual errors::Code status_code() const = 0; 650 virtual const string& status_error_message() const = 0; 651 virtual void set_status(const Status& status) = 0; 652 653 protected: 654 // Returns a mutable protobuf message that represents the contents of 655 // this wrapper, for passing to an RPC subsystem that will populate 656 // the message. 657 // 658 // NOTE: Only `MasterInterface` subclasses may call this method. The 659 // `InMemoryRunStepResponse` subclass does not implement this 660 // method, and attempts to call it will fail with a fatal 661 // error. However, as long as callers always call 662 // `MasterInterface::RunStep()` with a wrapper object returned 663 // from `MasterInterface::CreateRunStepResponse()` called on the 664 // *same* MasterInterface object, this error will never trigger. 665 virtual RunStepResponse* get_proto() = 0; 666 friend class MasterInterface; 667 }; 668 669 class InMemoryRunStepResponse : public MutableRunStepResponseWrapper { 670 public: 671 // MutableRunStepResponseWrapper methods. 672 size_t num_tensors() const override; 673 const string& tensor_name(size_t i) const override; 674 Status TensorValue(size_t i, Tensor* out_tensor) const override; 675 Status AddTensorFromRunGraphResponse( 676 const string& name, MutableRunGraphResponseWrapper* run_graph_response, 677 size_t i) override; 678 const RunMetadata& metadata() const override; 679 RunMetadata* mutable_metadata() override; 680 Status status() const override; 681 errors::Code status_code() const override; 682 const string& status_error_message() const override; 683 void set_status(const Status& status) override; 684 685 protected: 686 // NOTE: This method is not implemented. See 687 // MutableRunGraphResponseWrapper for an explanation. 688 RunStepResponse* get_proto() override; 689 690 private: 691 gtl::InlinedVector<std::pair<string, Tensor>, 4> tensors_; 692 RunMetadata metadata_; 693 // Store the code and message separately so that they can be updated 694 // independently by setters. 695 Status status_; 696 }; 697 698 // Proto-based message wrapper for use on the client side of the RunStep RPC. 699 class OwnedProtoRunStepResponse : public MutableRunStepResponseWrapper { 700 public: 701 // MutableRunStepResponseWrapper methods. 702 size_t num_tensors() const override; 703 const string& tensor_name(size_t i) const override; 704 Status TensorValue(size_t i, Tensor* out_tensor) const override; 705 Status AddTensorFromRunGraphResponse( 706 const string& name, MutableRunGraphResponseWrapper* run_graph_response, 707 size_t i) override; 708 const RunMetadata& metadata() const override; 709 RunMetadata* mutable_metadata() override; 710 Status status() const override; 711 errors::Code status_code() const override; 712 const string& status_error_message() const override; 713 void set_status(const Status& status) override; 714 715 protected: 716 RunStepResponse* get_proto() override; 717 718 private: 719 RunStepResponse response_; 720 }; 721 722 // Proto-based message wrapper for use on the server side of the RunStep RPC. 723 class NonOwnedProtoRunStepResponse : public MutableRunStepResponseWrapper { 724 public: 725 NonOwnedProtoRunStepResponse(RunStepResponse* response); 726 727 // MutableRunStepResponseWrapper methods. 728 size_t num_tensors() const override; 729 const string& tensor_name(size_t i) const override; 730 Status TensorValue(size_t i, Tensor* out_tensor) const override; 731 Status AddTensorFromRunGraphResponse( 732 const string& name, MutableRunGraphResponseWrapper* run_graph_response, 733 size_t i) override; 734 const RunMetadata& metadata() const override; 735 RunMetadata* mutable_metadata() override; 736 Status status() const override; 737 errors::Code status_code() const override; 738 const string& status_error_message() const override; 739 void set_status(const Status& status) override; 740 741 protected: 742 RunStepResponse* get_proto() override; 743 744 private: 745 RunStepResponse* response_; // Not owned. 746 }; 747 748 bool ParseTensorProtoToTensor(const TensorProto& tensor_proto, 749 Tensor* out_tensor); 750 751 } // namespace tensorflow 752 753 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_ 754