1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_ 16 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_ 17 18 #include <memory> 19 #include <utility> 20 #include <vector> 21 22 #include "absl/types/optional.h" 23 #include "tensorflow/core/common_runtime/device_mgr.h" 24 #include "tensorflow/core/common_runtime/eager/context.h" 25 #include "tensorflow/core/common_runtime/eager/eager_operation.h" 26 #include "tensorflow/core/framework/function.h" 27 #include "tensorflow/core/protobuf/remote_tensor_handle.pb.h" 28 29 namespace tensorflow { 30 31 class WorkerSession; 32 33 namespace eager { 34 35 // EagerClusterFunctionLibraryRuntime contains methods to Instantiate and Run 36 // functions across processes by making RPCs through eager service. 37 class EagerClusterFunctionLibraryRuntime 38 : public DistributedFunctionLibraryRuntime { 39 public: EagerClusterFunctionLibraryRuntime(const uint64 context_id,EagerContext * ctx,DeviceMgr * remote_device_mgr)40 EagerClusterFunctionLibraryRuntime(const uint64 context_id, EagerContext* ctx, 41 DeviceMgr* remote_device_mgr) 42 : context_id_(context_id), 43 ctx_(ctx), 44 remote_device_mgr_(remote_device_mgr) {} 45 ~EagerClusterFunctionLibraryRuntime()46 ~EagerClusterFunctionLibraryRuntime() override{}; 47 48 // Register a partition (i.e., component function) of a multi-device function 49 // on the remote target specified in `options.target`. This should be 50 // triggered as part of instantiating a multi-device function in 51 // ProcessFunctionLibraryRuntime. 52 void Instantiate(const string& function_name, 53 const FunctionLibraryDefinition& lib_def, AttrSlice attrs, 54 const FunctionLibraryRuntime::InstantiateOptions& options, 55 FunctionLibraryRuntime::LocalHandle* handle, 56 FunctionLibraryRuntime::DoneCallback done) override; 57 58 // Execute the component function specified by `handle` on its instantiated 59 // remote target. This should be triggered as part of driving a multi-device 60 // function execution in ProcessFunctionLibraryRuntime. Running the component 61 // function remotely is purely asynchronous, and multiple component functions 62 // with the same remote target are not executed in any particular ordering. 63 // The main function side must wait for all component functions to finish 64 // (i.e., the done callbacks triggered) before finishing its execution. 65 void Run(const FunctionLibraryRuntime::Options& opts, 66 FunctionLibraryRuntime::LocalHandle handle, 67 gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets, 68 FunctionLibraryRuntime::DoneCallback done) override; 69 70 // The component function inputs `args` and outputs `rets` may refer to remote 71 // tensors on a remote device, which will be lazily resolved remotely where 72 // the inputs/outputs are actually consumed. 73 void Run(const FunctionLibraryRuntime::Options& opts, 74 FunctionLibraryRuntime::LocalHandle handle, 75 gtl::ArraySlice<FunctionArg> args, std::vector<FunctionRet>* rets, 76 FunctionLibraryRuntime::DoneCallback done) override; 77 78 void CleanUp(uint64 step_id, FunctionLibraryRuntime::LocalHandle handle, 79 FunctionLibraryRuntime::DoneCallback done) override; 80 remote_device_mgr()81 DeviceMgr* remote_device_mgr() const override { return remote_device_mgr_; } 82 83 private: 84 const uint64 context_id_; 85 EagerContext* ctx_; 86 DeviceMgr* remote_device_mgr_; // not owned. 87 88 struct FunctionData { 89 const string target; 90 const absl::optional<std::vector<int>> ret_indices; 91 core::RefCountPtr<EagerClient> eager_client; 92 std::unique_ptr<EagerOperation> op; 93 FunctionDataFunctionData94 FunctionData(const string& target, 95 const absl::optional<std::vector<int>>& ret_indices, 96 EagerClient* eager_client, std::unique_ptr<EagerOperation> op) 97 : target(target), 98 ret_indices(ret_indices), 99 eager_client(core::RefCountPtr<EagerClient>(eager_client)), 100 op(std::move(op)) { 101 eager_client->Ref(); 102 } 103 }; 104 105 mutable mutex mu_; 106 std::vector<FunctionData> function_data_ TF_GUARDED_BY(mu_); 107 }; 108 109 DistributedFunctionLibraryRuntime* CreateClusterFLR( 110 const uint64 context_id, EagerContext* ctx, WorkerSession* worker_session); 111 112 } // namespace eager 113 } // namespace tensorflow 114 115 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_ 116