/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EXECUTE_NODE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EXECUTE_NODE_H_

// clang-format off
// Required for IS_MOBILE_PLATFORM
#include <cstddef>
#include <memory>
#include <string>
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/platform.h"
// clang-format on

#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
#endif  // IS_MOBILE_PLATFORM

namespace tensorflow {

// Wraps the inputs of a single eager op execution, resolving local, remote,
// and packed TensorHandles into the argument format expected by
// KernelAndDevice.
class ExecuteNodeArgs : public EagerKernelArgs {
 public:
  explicit ExecuteNodeArgs(int count) : EagerKernelArgs(count) {}

  Status Init(EagerContext* ctx,
              const absl::InlinedVector<TensorHandle*, 4>& op_inputs,
              const core::RefCountPtr<KernelAndDevice>& kernel);

  Status GetLocalArg(const FunctionArgIndex& index,
                     Tensor* val) const override;

  bool HasRemoteOrPackedInputs() const override {
    return has_remote_inputs_ || has_packed_inputs_;
  }

#if !defined(IS_MOBILE_PLATFORM)
  Status GetRemoteArg(const FunctionArgIndex& index,
                      eager::RemoteTensorHandle* val) const override {
    return serialize_remote_handle_(index, val);
  }
#endif  // IS_MOBILE_PLATFORM

 private:
#if !defined(IS_MOBILE_PLATFORM)
  // Returns whether `handle` is a remote handle or has a remote mirror on
  // `input_device`.
  bool IsRemote(EagerContext* ctx, Device* input_device, TensorHandle* handle);
#endif  // IS_MOBILE_PLATFORM

  // Initializes the packed TensorHandle that is the `index`-th argument.
  Status InitPackedHandle(const int index, EagerContext* ctx,
                          Device* input_device, TensorHandle* packed_handle);

  bool has_remote_inputs_ = false;
  bool has_packed_inputs_ = false;
  // Maps from the index of a packed arg to a list of sub-args.
  absl::flat_hash_map<int, gtl::InlinedVector<TensorValue, 4>> packed_args_;
#if !defined(IS_MOBILE_PLATFORM)
  std::function<Status(const FunctionArgIndex&, eager::RemoteTensorHandle*)>
      serialize_remote_handle_;
#endif  // IS_MOBILE_PLATFORM
};

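// Runs a kernel synchronously as part of eager execution. The input and
// output TensorHandles are borrowed from the caller, which must keep them
// alive until Run() returns.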
class ExecuteNode : public EagerNode {
 public:
  ExecuteNode(EagerContext* ctx,
              const absl::InlinedVector<TensorHandle*, 4>& inputs,
              const absl::optional<EagerFunctionParams>& eager_func_params,
              const core::RefCountPtr<KernelAndDevice>& kernel,
              GraphCollector* graph_collector,
              CancellationManager* cancellation_manager,
              absl::Span<TensorHandle*> retvals,
              absl::optional<ManagedStackTrace> stack_trace)
      : EagerNode(),
        ctx_(ctx),
        inputs_(inputs),
        eager_func_params_(eager_func_params),
        kernel_(kernel),
        graph_collector_(graph_collector),
        cancellation_manager_(cancellation_manager),
        retvals_(retvals),
        stack_trace_(stack_trace) {}

  Status Run() override {
    int i = 0;
    for (TensorHandle* h : inputs_) {
      if (h->RefCountIsOne()) {
        const Device* d = ctx_->CanonicalDevice(kernel_->InputDevice(i));
        Status s = h->Unprotect(d);
        if (!s.ok()) {
          VLOG(1) << "Unable to unprotect tensor: " << s;
        }
      }
      ++i;
    }
    return EagerKernelExecute(ctx_, inputs_, eager_func_params_, kernel_,
                              graph_collector_, cancellation_manager_,
                              retvals_, stack_trace_);
  }

  // A no-op: in the synchronous case any error is returned to the caller
  // directly from Run().
  void Abort(Status status) override {}

  std::string DebugString() const override {
    std::string out = "[ExecuteNode]";
    strings::StrAppend(&out, " kernel: ", kernel_->name());
    return out;
  }

 private:
  EagerContext* ctx_;
  const absl::InlinedVector<TensorHandle*, 4>& inputs_;
  const absl::optional<EagerFunctionParams>& eager_func_params_;
  const core::RefCountPtr<KernelAndDevice>& kernel_;
  GraphCollector* graph_collector_;
  CancellationManager* const cancellation_manager_;
  absl::Span<TensorHandle*> retvals_;
  absl::optional<ManagedStackTrace> stack_trace_;
};

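// Like ExecuteNode, but safe to enqueue for asynchronous execution: the node
// takes a reference on every input and output TensorHandle so that they
// outlive the caller, and poisons the output handles if the node is aborted.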
class AsyncExecuteNode : public EagerNode {
 public:
  AsyncExecuteNode(EagerContext* ctx,
                   const absl::InlinedVector<TensorHandle*, 4>& inputs,
                   const absl::optional<EagerFunctionParams>& eager_func_params,
                   core::RefCountPtr<KernelAndDevice> kernel,
                   GraphCollector* graph_collector,
                   CancellationManager* cancellation_manager,
                   absl::Span<TensorHandle*> retvals,
                   absl::optional<ManagedStackTrace> stack_trace)
      : EagerNode(),
        ctx_(ctx),
        inputs_(inputs),
        eager_func_params_(eager_func_params),
        kernel_(std::move(kernel)),
        graph_collector_(graph_collector),
        cancellation_manager_(cancellation_manager),
        stack_trace_(stack_trace) {
    // Copy the output handles, since the container for them might get
    // destroyed.
    for (auto handle : retvals) {
      handle->Ref();
      retvals_.push_back(handle);
    }

    // This is required to ensure that the tensor handles stay alive across
    // the execution.
    for (auto handle : inputs_) {
      handle->Ref();
    }
  }

  ~AsyncExecuteNode() override {
    for (auto handle : retvals_) {
      handle->Unref();
    }

    for (auto handle : inputs_) {
      handle->Unref();
    }
  }

  Status Run() override {
    int i = 0;
    for (TensorHandle* h : inputs_) {
      if (h->RefCountIsOne()) {
        const Device* d = ctx_->CanonicalDevice(kernel_->InputDevice(i));
        Status s = h->Unprotect(d);
        if (!s.ok()) {
          VLOG(1) << "Unable to unprotect tensor: " << s;
        }
      }
      ++i;
    }
    Status status = EagerKernelExecute(
        ctx_, inputs_, eager_func_params_, kernel_, graph_collector_,
        cancellation_manager_, absl::MakeSpan(retvals_), stack_trace_);
    if (!status.ok()) {
      if (stack_trace_.has_value()) {
        errors::SetStackTrace(status, stack_trace_->ToStackFrames({}, {}));
      }
      Abort(status);
      return status;
    }
    // If status is ok, EagerKernelExecute would have called SetTensor on
    // all the output handles.
    return OkStatus();
  }

  void Abort(Status status) override {
    int i = 0;
    for (auto handle : retvals_) {
      handle->Poison(status, ctx_->CanonicalDevice(kernel_->OutputDevice(i)));
      ++i;
    }
  }

  std::string DebugString() const override {
    std::string out = "[AsyncExecuteNode]";
    strings::StrAppend(&out, " kernel: ", kernel_->name());
    return out;
  }

 private:
  EagerContext* ctx_;
  absl::InlinedVector<TensorHandle*, 4> inputs_;
  const absl::optional<EagerFunctionParams> eager_func_params_;
  core::RefCountPtr<KernelAndDevice> kernel_;
  GraphCollector* graph_collector_;
  CancellationManager* const cancellation_manager_;
  absl::optional<ManagedStackTrace> stack_trace_;
  absl::InlinedVector<TensorHandle*, 2> retvals_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EXECUTE_NODE_H_