xref: /aosp_15_r20/external/tensorflow/tensorflow/core/common_runtime/eager/execute_node.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EXECUTE_NODE_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EXECUTE_NODE_H_
17 
18 // clang-format off
19 // Required for IS_MOBILE_PLATFORM
20 #include <cstddef>
21 #include <memory>
22 #include <string>
23 #include "absl/container/flat_hash_map.h"
24 #include "tensorflow/core/platform/errors.h"
25 #include "tensorflow/core/platform/platform.h"
26 // clang-format on
27 
28 #include "absl/container/inlined_vector.h"
29 #include "absl/memory/memory.h"
30 #include "absl/types/optional.h"
31 #include "absl/types/span.h"
32 #include "tensorflow/core/common_runtime/device.h"
33 #include "tensorflow/core/common_runtime/eager/context.h"
34 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
35 #include "tensorflow/core/common_runtime/eager/execute.h"
36 #include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
37 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
38 #include "tensorflow/core/framework/step_stats.pb.h"
39 #include "tensorflow/core/framework/tensor.h"
40 #include "tensorflow/core/framework/types.h"
41 #include "tensorflow/core/lib/core/status.h"
42 #include "tensorflow/core/lib/strings/strcat.h"
43 #if !defined(IS_MOBILE_PLATFORM)
44 #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
45 #include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
46 #endif  // IS_MOBILE_PLATFORM
47 
48 namespace tensorflow {
49 
50 class ExecuteNodeArgs : public EagerKernelArgs {
51  public:
ExecuteNodeArgs(int count)52   explicit ExecuteNodeArgs(int count) : EagerKernelArgs(count) {}
53 
54   Status Init(EagerContext* ctx,
55               const absl::InlinedVector<TensorHandle*, 4>& op_inputs,
56               const core::RefCountPtr<KernelAndDevice>& kernel);
57 
58   Status GetLocalArg(const FunctionArgIndex& index, Tensor* val) const override;
59 
HasRemoteOrPackedInputs()60   bool HasRemoteOrPackedInputs() const override {
61     return has_remote_inputs_ || has_packed_inputs_;
62   };
63 
64 #if !defined(IS_MOBILE_PLATFORM)
GetRemoteArg(const FunctionArgIndex & index,eager::RemoteTensorHandle * val)65   Status GetRemoteArg(const FunctionArgIndex& index,
66                       eager::RemoteTensorHandle* val) const override {
67     return serialize_remote_handle_(index, val);
68   }
69 #endif  // IS_MOBILE_PLATFORM
70 
71  private:
72 #if !defined(IS_MOBILE_PLATFORM)
73   // Returns whether `handle` is a remote handle or has a remote mirror on
74   // `input_device`
75   bool IsRemote(EagerContext* ctx, Device* input_device, TensorHandle* handle);
76 #endif  // IS_MOBILE_PLATFORM
77 
78   // Initialize a packed TensorHandle which is the `index`-th argument.
79   Status InitPackedHandle(const int index, EagerContext* ctx,
80                           Device* input_device, TensorHandle* packed_handle);
81 
82   bool has_remote_inputs_ = false;
83   bool has_packed_inputs_ = false;
84   // Maps from the index of a packed arg to a list of sub-args.
85   absl::flat_hash_map<int, gtl::InlinedVector<TensorValue, 4>> packed_args_;
86 #if !defined(IS_MOBILE_PLATFORM)
87   std::function<Status(const FunctionArgIndex&, eager::RemoteTensorHandle*)>
88       serialize_remote_handle_;
89 #endif  // IS_MOBILE_PLATFORM
90 };
91 
92 class ExecuteNode : public EagerNode {
93  public:
ExecuteNode(EagerContext * ctx,const absl::InlinedVector<TensorHandle *,4> & inputs,const absl::optional<EagerFunctionParams> & eager_func_params,const core::RefCountPtr<KernelAndDevice> & kernel,GraphCollector * graph_collector,CancellationManager * cancellation_manager,absl::Span<TensorHandle * > retvals,absl::optional<ManagedStackTrace> stack_trace)94   ExecuteNode(EagerContext* ctx,
95               const absl::InlinedVector<TensorHandle*, 4>& inputs,
96               const absl::optional<EagerFunctionParams>& eager_func_params,
97               const core::RefCountPtr<KernelAndDevice>& kernel,
98               GraphCollector* graph_collector,
99               CancellationManager* cancellation_manager,
100               absl::Span<TensorHandle*> retvals,
101               absl::optional<ManagedStackTrace> stack_trace)
102       : EagerNode(),
103         ctx_(ctx),
104         inputs_(inputs),
105         eager_func_params_(eager_func_params),
106         kernel_(kernel),
107         graph_collector_(graph_collector),
108         cancellation_manager_(cancellation_manager),
109         retvals_(retvals),
110         stack_trace_(stack_trace) {}
111 
Run()112   Status Run() override {
113     int i = 0;
114     for (TensorHandle* h : inputs_) {
115       if (h->RefCountIsOne()) {
116         const Device* d = ctx_->CanonicalDevice(kernel_->InputDevice(i));
117         Status s = h->Unprotect(d);
118         if (!s.ok()) {
119           VLOG(1) << "Unable to unprotect tensor: " << s;
120         }
121       }
122       ++i;
123     }
124     return EagerKernelExecute(ctx_, inputs_, eager_func_params_, kernel_,
125                               graph_collector_, cancellation_manager_, retvals_,
126                               stack_trace_);
127   }
128 
Abort(Status status)129   void Abort(Status status) override {}
130 
DebugString()131   std::string DebugString() const override {
132     std::string out = "[ExecuteNode]";
133     strings::StrAppend(&out, " kernel: ", kernel_->name());
134     return out;
135   }
136 
137  private:
138   EagerContext* ctx_;
139   const absl::InlinedVector<TensorHandle*, 4>& inputs_;
140   const absl::optional<EagerFunctionParams>& eager_func_params_;
141   const core::RefCountPtr<KernelAndDevice>& kernel_;
142   GraphCollector* graph_collector_;
143   CancellationManager* const cancellation_manager_;
144   absl::Span<TensorHandle*> retvals_;
145   absl::optional<ManagedStackTrace> stack_trace_;
146 };
147 
// Asynchronous eager-op execution node. Unlike ExecuteNode, it takes its
// own references on every input and output handle (Ref in the constructor,
// Unref in the destructor) so the handles stay alive across a deferred
// execution, and it poisons the outputs when execution fails.
class AsyncExecuteNode : public EagerNode {
 public:
  AsyncExecuteNode(EagerContext* ctx,
                   const absl::InlinedVector<TensorHandle*, 4>& inputs,
                   const absl::optional<EagerFunctionParams>& eager_func_params,
                   core::RefCountPtr<KernelAndDevice> kernel,
                   GraphCollector* graph_collector,
                   CancellationManager* cancellation_manager,
                   absl::Span<TensorHandle*> retvals,
                   absl::optional<ManagedStackTrace> stack_trace)
      : EagerNode(),
        ctx_(ctx),
        inputs_(inputs),
        eager_func_params_(eager_func_params),
        kernel_(std::move(kernel)),
        graph_collector_(graph_collector),
        cancellation_manager_(cancellation_manager),
        stack_trace_(stack_trace) {
    // Copy the output handles, since the container for them might get
    // destroyed.
    for (auto handle : retvals) {
      handle->Ref();
      retvals_.push_back(handle);
    }

    // This is required to ensure that the tensor handles stay alive across
    // the execution.
    for (auto handle : inputs_) {
      handle->Ref();
    }
  }

  ~AsyncExecuteNode() override {
    // Release the references taken in the constructor.
    for (auto handle : retvals_) {
      handle->Unref();
    }

    for (auto handle : inputs_) {
      handle->Unref();
    }
  }

  Status Run() override {
    int i = 0;
    for (TensorHandle* h : inputs_) {
      // Since this node Ref'd every input, RefCountIsOne() means this node
      // is the handle's sole owner; presumably Unprotect() then lets the
      // kernel reuse the input buffer — confirm against TensorHandle.
      if (h->RefCountIsOne()) {
        const Device* d = ctx_->CanonicalDevice(kernel_->InputDevice(i));
        Status s = h->Unprotect(d);
        if (!s.ok()) {
          // Best effort only: failure is logged and execution proceeds.
          VLOG(1) << "Unable to unprotect tensor: " << s;
        }
      }
      ++i;
    }
    Status status = EagerKernelExecute(
        ctx_, inputs_, eager_func_params_, kernel_, graph_collector_,
        cancellation_manager_, absl::MakeSpan(retvals_), stack_trace_);
    if (!status.ok()) {
      if (stack_trace_.has_value()) {
        errors::SetStackTrace(status, stack_trace_->ToStackFrames({}, {}));
      }
      // Poison the outputs so any consumer observes the failure status.
      Abort(status);
      return status;
    }
    // If status is ok, EagerKernelExecute would have called SetTensor on
    // all the output handles.
    return OkStatus();
  }

  // Marks every output handle as failed with `status`.
  void Abort(Status status) override {
    int i = 0;
    for (auto handle : retvals_) {
      handle->Poison(status, ctx_->CanonicalDevice(kernel_->OutputDevice(i)));
      ++i;
    }
  }

  std::string DebugString() const override {
    std::string out = "[AsyncExecuteNode]";
    strings::StrAppend(&out, " kernel: ", kernel_->name());
    return out;
  }

 private:
  EagerContext* ctx_;
  // Owned by value (with a Ref per element), unlike ExecuteNode which
  // holds references to caller-owned storage.
  absl::InlinedVector<TensorHandle*, 4> inputs_;
  const absl::optional<EagerFunctionParams> eager_func_params_;
  core::RefCountPtr<KernelAndDevice> kernel_;
  GraphCollector* graph_collector_;
  CancellationManager* const cancellation_manager_;
  absl::optional<ManagedStackTrace> stack_trace_;
  absl::InlinedVector<TensorHandle*, 2> retvals_;
};
241 
242 }  // namespace tensorflow
243 
244 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EXECUTE_NODE_H_
245