1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
16 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
17 
18 #include <memory>
19 #include <utility>
20 #include <vector>
21 
22 #include "absl/types/optional.h"
23 #include "tensorflow/core/common_runtime/device_mgr.h"
24 #include "tensorflow/core/common_runtime/eager/context.h"
25 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
26 #include "tensorflow/core/framework/function.h"
27 #include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
28 
29 namespace tensorflow {
30 
31 class WorkerSession;
32 
33 namespace eager {
34 
35 // EagerClusterFunctionLibraryRuntime contains methods to Instantiate and Run
36 // functions across processes by making RPCs through eager service.
37 class EagerClusterFunctionLibraryRuntime
38     : public DistributedFunctionLibraryRuntime {
39  public:
EagerClusterFunctionLibraryRuntime(const uint64 context_id,EagerContext * ctx,DeviceMgr * remote_device_mgr)40   EagerClusterFunctionLibraryRuntime(const uint64 context_id, EagerContext* ctx,
41                                      DeviceMgr* remote_device_mgr)
42       : context_id_(context_id),
43         ctx_(ctx),
44         remote_device_mgr_(remote_device_mgr) {}
45 
~EagerClusterFunctionLibraryRuntime()46   ~EagerClusterFunctionLibraryRuntime() override{};
47 
48   // Register a partition (i.e., component function) of a multi-device function
49   // on the remote target specified in `options.target`. This should be
50   // triggered as part of instantiating a multi-device function in
51   // ProcessFunctionLibraryRuntime.
52   void Instantiate(const string& function_name,
53                    const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
54                    const FunctionLibraryRuntime::InstantiateOptions& options,
55                    FunctionLibraryRuntime::LocalHandle* handle,
56                    FunctionLibraryRuntime::DoneCallback done) override;
57 
58   // Execute the component function specified by `handle` on its instantiated
59   // remote target. This should be triggered as part of driving a multi-device
60   // function execution in ProcessFunctionLibraryRuntime. Running the component
61   // function remotely is purely asynchronous, and multiple component functions
62   // with the same remote target are not executed in any particular ordering.
63   // The main function side must wait for all component functions to finish
64   // (i.e., the done callbacks triggered) before finishing its execution.
65   void Run(const FunctionLibraryRuntime::Options& opts,
66            FunctionLibraryRuntime::LocalHandle handle,
67            gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
68            FunctionLibraryRuntime::DoneCallback done) override;
69 
70   // The component function inputs `args` and outputs `rets` may refer to remote
71   // tensors on a remote device, which will be lazily resolved remotely where
72   // the inputs/outputs are actually consumed.
73   void Run(const FunctionLibraryRuntime::Options& opts,
74            FunctionLibraryRuntime::LocalHandle handle,
75            gtl::ArraySlice<FunctionArg> args, std::vector<FunctionRet>* rets,
76            FunctionLibraryRuntime::DoneCallback done) override;
77 
78   void CleanUp(uint64 step_id, FunctionLibraryRuntime::LocalHandle handle,
79                FunctionLibraryRuntime::DoneCallback done) override;
80 
remote_device_mgr()81   DeviceMgr* remote_device_mgr() const override { return remote_device_mgr_; }
82 
83  private:
84   const uint64 context_id_;
85   EagerContext* ctx_;
86   DeviceMgr* remote_device_mgr_;  // not owned.
87 
88   struct FunctionData {
89     const string target;
90     const absl::optional<std::vector<int>> ret_indices;
91     core::RefCountPtr<EagerClient> eager_client;
92     std::unique_ptr<EagerOperation> op;
93 
FunctionDataFunctionData94     FunctionData(const string& target,
95                  const absl::optional<std::vector<int>>& ret_indices,
96                  EagerClient* eager_client, std::unique_ptr<EagerOperation> op)
97         : target(target),
98           ret_indices(ret_indices),
99           eager_client(core::RefCountPtr<EagerClient>(eager_client)),
100           op(std::move(op)) {
101       eager_client->Ref();
102     }
103   };
104 
105   mutable mutex mu_;
106   std::vector<FunctionData> function_data_ TF_GUARDED_BY(mu_);
107 };
108 
109 DistributedFunctionLibraryRuntime* CreateClusterFLR(
110     const uint64 context_id, EagerContext* ctx, WorkerSession* worker_session);
111 
112 }  // namespace eager
113 }  // namespace tensorflow
114 
115 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
116