xref: /aosp_15_r20/external/tensorflow/tensorflow/core/distributed_runtime/eager/remote_mgr.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_MGR_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_MGR_H_
18 
19 #include <unordered_map>
20 #include <vector>
21 
22 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
23 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
24 #include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h"
25 #include "tensorflow/core/platform/mutex.h"
26 
27 namespace tensorflow {
28 namespace eager {
29 
30 // This class manages the states required to setup an eager cluster.
31 // TODO(fishx): Move remote state from context to this class.
32 class RemoteMgr {
33  public:
RemoteMgr(bool is_master,EagerContext * ctx)34   RemoteMgr(bool is_master, EagerContext* ctx)
35       : is_master_(is_master), parent_(ctx) {}
36 
~RemoteMgr()37   ~RemoteMgr() {
38     for (const auto& entry : remote_tensor_handle_map_) {
39       entry.second->Unref();
40     }
41   }
42 
IsMaster()43   bool IsMaster() { return is_master_; }
44 
45   void AddOperationOutputs(
46       const gtl::ArraySlice<tensorflow::TensorHandle*> handles,
47       int64_t operation_id);
48 
49   void AddOperationOutput(tensorflow::TensorHandle* handles,
50                           int64_t operation_id, int32_t output_num);
51 
52   Status GetTensorHandle(const RemoteTensorHandleInternal& remote_handle,
53                          tensorflow::TensorHandle** handle);
54 
55   Status DeleteTensorHandle(const RemoteTensorHandleInternal& remote_handle);
56 
57   // Helper function to create monotonically increasing ids unique to this
58   // context.
NextOpId()59   uint64 NextOpId() {
60     DCHECK(is_master_);
61     mutex_lock l(next_id_mutex_);
62     return next_op_id_++;
63   }
64 
65   // Serialize a remote TensorHandle to a RemoteTensorHandle.
66   // If wait_until_ready is true, block until the remote handle is ready on a
67   // remote worker.
68   Status SerializeRemoteTensorHandle(
69       TensorHandle* in, const bool wait_until_ready, RemoteTensorHandle* out,
70       Device* device, const string& device_name,
71       const bool serialize_resource_dtype_and_shape = false);
72 
73   // Deserialize a RemoteTensorHandle to a TensorHandle(local/remote).
74   // The output holds a reference to the TensorHandle.
75   Status DeserializeRemoteTensorHandle(const RemoteTensorHandle& in,
76                                        TensorHandle** out);
77 
78   EagerExecutor& GetOrCreateExecutorForStream(uint64 stream_id);
79 
80   void DeleteExecutorForStream(uint64 stream_id);
81 
82  protected:
83   mutex next_id_mutex_;
84   uint64 next_op_id_ TF_GUARDED_BY(next_id_mutex_) = 1;
85 
86  private:
87   // Returns the op_id and output_num if the given local TensorHandle exists in
88   // remote_tensor_handle_map_.
89   Status GetRemoteTensorHandle(const tensorflow::TensorHandle* handle,
90                                const bool wait_until_ready, int64_t* op_id,
91                                int32* output_num)
92       TF_SHARED_LOCKS_REQUIRED(remote_tensor_handle_mu_);
93 
94   Status GetTensorHandleImpl(const RemoteTensorHandleInternal& remote_handle,
95                              tensorflow::TensorHandle** handle)
96       TF_SHARED_LOCKS_REQUIRED(remote_tensor_handle_mu_);
97 
98   Status GetMirroredResourceShape(
99       const RemoteTensorHandleInternal& remote_handle,
100       std::vector<DtypeAndPartialTensorShape>* handle);
101 
102   bool is_master_;
103 
104   using RemoteTensorHandleMap =
105       gtl::FlatMap<RemoteTensorHandleInternal, tensorflow::TensorHandle*,
106                    RemoteTensorHandleInternalHash,
107                    RemoteTensorHandleInternalEquals>;
108   using MirroredResourceShapeMap = gtl::FlatMap<
109       RemoteTensorHandleInternal, std::vector<DtypeAndPartialTensorShape>,
110       RemoteTensorHandleInternalHash, RemoteTensorHandleInternalEquals>;
111 
112   mutex remote_tensor_handle_mu_;
113   // This map maintains the TensorHandles that are required by remote workers
114   // in the cluster. Each map key is generated by the master, so it should be
115   // globally unique. This map owns references on the handles it contains.
116   RemoteTensorHandleMap remote_tensor_handle_map_
117       TF_GUARDED_BY(remote_tensor_handle_mu_);
118 
119   mutex mirrored_resource_shape_mu_;
120   // This map maintains the data types and shapes of resource variables required
121   // by remote workers in the cluster. Each map key is generated by the master,
122   // so it should be globally unique.
123   MirroredResourceShapeMap mirrored_resource_shape_map_
124       TF_GUARDED_BY(mirrored_resource_shape_mu_);
125 
126   EagerContext* parent_;  // not owned.
127 
128   mutex executor_map_mu_;
129   std::unordered_map<uint64, EagerExecutor> executor_map_
130       TF_GUARDED_BY(executor_map_mu_);
131 };
132 
133 }  // namespace eager
134 }  // namespace tensorflow
135 
136 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_MGR_H_
137