1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_EXECUTE_NODE_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_EXECUTE_NODE_H_ 18 19 #include <cstddef> 20 #include <memory> 21 #include <utility> 22 23 #include "absl/types/span.h" 24 #include "tensorflow/core/common_runtime/eager/eager_executor.h" 25 #include "tensorflow/core/common_runtime/eager/shape_inference.h" 26 #include "tensorflow/core/common_runtime/eager/tensor_handle.h" 27 #include "tensorflow/core/distributed_runtime/eager/eager_client.h" 28 #include "tensorflow/core/framework/cancellation.h" 29 #include "tensorflow/core/framework/function.h" 30 #include "tensorflow/core/framework/node_def.pb.h" 31 #include "tensorflow/core/lib/gtl/inlined_vector.h" 32 #include "tensorflow/core/protobuf/eager_service.pb.h" 33 34 namespace tensorflow { 35 namespace eager { 36 37 // RemoteExecuteNode is an implementation of EagerNode which enqueues 38 // an operation via RPC in a remote EagerService. 39 class RemoteExecuteNode : public AsyncRemoteExecuteNode { 40 public: RemoteExecuteNode(EagerContext * eager_context,std::unique_ptr<EnqueueRequest> request,Device * device,uint64 context_view_id,EagerClient * eager_client,CancellationManager * cancellation_manager,const NodeDef & ndef,FunctionLibraryDefinition * lib_def,const gtl::InlinedVector<TensorHandle *,4> & inputs,absl::Span<TensorHandle * > retvals)41 RemoteExecuteNode(EagerContext* eager_context, 42 std::unique_ptr<EnqueueRequest> request, Device* device, 43 uint64 context_view_id, EagerClient* eager_client, 44 CancellationManager* cancellation_manager, 45 const NodeDef& ndef, FunctionLibraryDefinition* lib_def, 46 const gtl::InlinedVector<TensorHandle*, 4>& inputs, 47 absl::Span<TensorHandle*> retvals) 48 : AsyncRemoteExecuteNode(), 49 eager_context_(eager_context), 50 request_(std::move(request)), 51 device_(device), 52 context_view_id_(context_view_id), 53 eager_client_(eager_client), 54 cancellation_manager_(cancellation_manager), 55 ndef_(ndef), 56 lib_def_(lib_def), 57 inputs_(inputs) { 58 // Copy the output handles, since the container for them might get 59 // destroyed. 60 for (auto handle : retvals) { 61 handle->Ref(); 62 retvals_.push_back(handle); 63 } 64 65 // This is required to ensure that the tensor handles stay alive across the 66 // execution. 67 for (auto handle : inputs_) { 68 handle->Ref(); 69 } 70 eager_client_->Ref(); 71 72 needs_remote_inputs_ = false; 73 for (const TensorHandle* input : inputs_) { 74 // TODO(bramandia): Should this be op_device() instead? 75 if (input->resource_device() != nullptr && 76 input->resource_device() != device_) { 77 needs_remote_inputs_ = true; 78 break; 79 } 80 } 81 } 82 ~RemoteExecuteNode()83 ~RemoteExecuteNode() override { 84 for (auto handle : retvals_) { 85 handle->Unref(); 86 } 87 88 for (auto handle : inputs_) { 89 handle->Unref(); 90 } 91 eager_client_->Unref(); 92 } 93 Prepare()94 Status Prepare() override { 95 return RunShapeInference(ndef_, *lib_def_, inputs_, retvals_); 96 } 97 98 void RunAsync(StatusCallback done) override; 99 SyncExecutors()100 Status SyncExecutors() override { return eager_context_->SyncExecutors(); } 101 Abort(Status status)102 void Abort(Status status) override { 103 int i = 0; 104 for (auto handle : retvals_) { 105 handle->PoisonRemote(status, device_, context_view_id_); 106 ++i; 107 } 108 } 109 eager_client()110 const EagerClient* eager_client() const override { return eager_client_; } 111 needs_remote_inputs()112 bool needs_remote_inputs() const override { return needs_remote_inputs_; } 113 allow_multiple_pending_requests()114 bool allow_multiple_pending_requests() const override { 115 return eager_client_->allow_multiple_pending_requests(); 116 } 117 DebugString()118 string DebugString() const override { 119 string out = "[RemoteExecuteNode]"; 120 strings::StrAppend(&out, " request: ", request_->DebugString()); 121 strings::StrAppend(&out, ", target_device: ", device_->name()); 122 return out; 123 } 124 125 private: 126 EagerContext* eager_context_; // Not owned, and must outlive this node. 127 std::unique_ptr<EnqueueRequest> request_; 128 Device* device_; // Not owned 129 uint64 context_view_id_; 130 bool needs_remote_inputs_; 131 EagerClient* eager_client_; // Not owned, and must outlive this node. 132 CancellationManager* cancellation_manager_; 133 const NodeDef ndef_; 134 const FunctionLibraryDefinition* lib_def_; 135 gtl::InlinedVector<TensorHandle*, 4> inputs_; 136 gtl::InlinedVector<TensorHandle*, 2> retvals_; 137 }; 138 139 } // namespace eager 140 } // namespace tensorflow 141 142 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_EXECUTE_NODE_H_ 143