xref: /aosp_15_r20/external/tensorflow/tensorflow/core/distributed_runtime/eager/remote_execute_node.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_EXECUTE_NODE_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_EXECUTE_NODE_H_
18 
19 #include <cstddef>
20 #include <memory>
21 #include <utility>
22 
23 #include "absl/types/span.h"
24 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
25 #include "tensorflow/core/common_runtime/eager/shape_inference.h"
26 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
27 #include "tensorflow/core/distributed_runtime/eager/eager_client.h"
28 #include "tensorflow/core/framework/cancellation.h"
29 #include "tensorflow/core/framework/function.h"
30 #include "tensorflow/core/framework/node_def.pb.h"
31 #include "tensorflow/core/lib/gtl/inlined_vector.h"
32 #include "tensorflow/core/protobuf/eager_service.pb.h"
33 
34 namespace tensorflow {
35 namespace eager {
36 
37 // RemoteExecuteNode is an implementation of EagerNode which enqueues
38 // an operation via RPC in a remote EagerService.
39 class RemoteExecuteNode : public AsyncRemoteExecuteNode {
40  public:
RemoteExecuteNode(EagerContext * eager_context,std::unique_ptr<EnqueueRequest> request,Device * device,uint64 context_view_id,EagerClient * eager_client,CancellationManager * cancellation_manager,const NodeDef & ndef,FunctionLibraryDefinition * lib_def,const gtl::InlinedVector<TensorHandle *,4> & inputs,absl::Span<TensorHandle * > retvals)41   RemoteExecuteNode(EagerContext* eager_context,
42                     std::unique_ptr<EnqueueRequest> request, Device* device,
43                     uint64 context_view_id, EagerClient* eager_client,
44                     CancellationManager* cancellation_manager,
45                     const NodeDef& ndef, FunctionLibraryDefinition* lib_def,
46                     const gtl::InlinedVector<TensorHandle*, 4>& inputs,
47                     absl::Span<TensorHandle*> retvals)
48       : AsyncRemoteExecuteNode(),
49         eager_context_(eager_context),
50         request_(std::move(request)),
51         device_(device),
52         context_view_id_(context_view_id),
53         eager_client_(eager_client),
54         cancellation_manager_(cancellation_manager),
55         ndef_(ndef),
56         lib_def_(lib_def),
57         inputs_(inputs) {
58     // Copy the output handles, since the container for them might get
59     // destroyed.
60     for (auto handle : retvals) {
61       handle->Ref();
62       retvals_.push_back(handle);
63     }
64 
65     // This is required to ensure that the tensor handles stay alive across the
66     // execution.
67     for (auto handle : inputs_) {
68       handle->Ref();
69     }
70     eager_client_->Ref();
71 
72     needs_remote_inputs_ = false;
73     for (const TensorHandle* input : inputs_) {
74       // TODO(bramandia): Should this be op_device() instead?
75       if (input->resource_device() != nullptr &&
76           input->resource_device() != device_) {
77         needs_remote_inputs_ = true;
78         break;
79       }
80     }
81   }
82 
~RemoteExecuteNode()83   ~RemoteExecuteNode() override {
84     for (auto handle : retvals_) {
85       handle->Unref();
86     }
87 
88     for (auto handle : inputs_) {
89       handle->Unref();
90     }
91     eager_client_->Unref();
92   }
93 
Prepare()94   Status Prepare() override {
95     return RunShapeInference(ndef_, *lib_def_, inputs_, retvals_);
96   }
97 
98   void RunAsync(StatusCallback done) override;
99 
SyncExecutors()100   Status SyncExecutors() override { return eager_context_->SyncExecutors(); }
101 
Abort(Status status)102   void Abort(Status status) override {
103     int i = 0;
104     for (auto handle : retvals_) {
105       handle->PoisonRemote(status, device_, context_view_id_);
106       ++i;
107     }
108   }
109 
eager_client()110   const EagerClient* eager_client() const override { return eager_client_; }
111 
needs_remote_inputs()112   bool needs_remote_inputs() const override { return needs_remote_inputs_; }
113 
allow_multiple_pending_requests()114   bool allow_multiple_pending_requests() const override {
115     return eager_client_->allow_multiple_pending_requests();
116   }
117 
DebugString()118   string DebugString() const override {
119     string out = "[RemoteExecuteNode]";
120     strings::StrAppend(&out, " request: ", request_->DebugString());
121     strings::StrAppend(&out, ", target_device: ", device_->name());
122     return out;
123   }
124 
125  private:
126   EagerContext* eager_context_;  // Not owned, and must outlive this node.
127   std::unique_ptr<EnqueueRequest> request_;
128   Device* device_;             // Not owned
129   uint64 context_view_id_;
130   bool needs_remote_inputs_;
131   EagerClient* eager_client_;  // Not owned, and must outlive this node.
132   CancellationManager* cancellation_manager_;
133   const NodeDef ndef_;
134   const FunctionLibraryDefinition* lib_def_;
135   gtl::InlinedVector<TensorHandle*, 4> inputs_;
136   gtl::InlinedVector<TensorHandle*, 2> retvals_;
137 };
138 
139 }  // namespace eager
140 }  // namespace tensorflow
141 
142 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_EXECUTE_NODE_H_
143