1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_TEST_COLLECTIVE_EXECUTOR_MGR_H_ 16 #define TENSORFLOW_CORE_COMMON_RUNTIME_TEST_COLLECTIVE_EXECUTOR_MGR_H_ 17 18 #include "tensorflow/core/framework/collective.h" 19 #include "tensorflow/core/framework/device_attributes.pb.h" 20 #include "tensorflow/core/lib/gtl/flatmap.h" 21 22 namespace tensorflow { 23 24 // Mock objects that can't actually execute a Collective, but satisfy 25 // general infrastructure expectations within tests that don't require 26 // full functionality. 27 28 class TestCollectiveExecutor : public CollectiveExecutor { 29 public: 30 explicit TestCollectiveExecutor(CollectiveExecutorMgrInterface* cem, 31 CollectiveRemoteAccess* rma = nullptr) CollectiveExecutor(cem)32 : CollectiveExecutor(cem), rma_(rma) {} 33 RunClosure(std::function<void ()> fn)34 void RunClosure(std::function<void()> fn) override { fn(); } 35 remote_access()36 CollectiveRemoteAccess* remote_access() override { return rma_; } 37 38 private: 39 CollectiveRemoteAccess* rma_; 40 }; 41 42 class TestParamResolver : public ParamResolverInterface { CompleteParamsAsync(const DeviceAttributes & device,CollectiveParams * cp,CancellationManager * cancel_mgr,const StatusCallback & done)43 void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp, 44 CancellationManager* cancel_mgr, 45 const StatusCallback& done) override { 46 done(errors::Internal("Unimplemented")); 47 } 48 CompleteGroupAsync(const DeviceAttributes & device,CollGroupParams * group_params,CancellationManager * cancel_mgr,const StatusCallback & done)49 void CompleteGroupAsync(const DeviceAttributes& device, 50 CollGroupParams* group_params, 51 CancellationManager* cancel_mgr, 52 const StatusCallback& done) override { 53 done(errors::Internal("Unimplemented")); 54 } 55 CompleteInstanceAsync(const CompleteInstanceRequest * request,CompleteInstanceResponse * response,CancellationManager * cancel_mgr,const StatusCallback & done)56 void CompleteInstanceAsync(const CompleteInstanceRequest* request, 57 CompleteInstanceResponse* response, 58 CancellationManager* cancel_mgr, 59 const StatusCallback& done) override { 60 done(errors::Internal("Unimplemented")); 61 } 62 LookupGroup(int32_t group_key,CollGroupParams * group)63 Status LookupGroup(int32_t group_key, CollGroupParams* group) override { 64 return errors::Internal("Unimplemented"); 65 } 66 StartAbort(const Status & s)67 void StartAbort(const Status& s) override {} 68 }; 69 70 class TestCollectiveExecutorMgr : public CollectiveExecutorMgrInterface { 71 public: TestCollectiveExecutorMgr(ParamResolverInterface * param_resolver,CollectiveRemoteAccess * rma)72 explicit TestCollectiveExecutorMgr(ParamResolverInterface* param_resolver, 73 CollectiveRemoteAccess* rma) 74 : param_resolver_(param_resolver), rma_(rma) {} 75 TestCollectiveExecutorMgr()76 TestCollectiveExecutorMgr() : param_resolver_(nullptr), rma_(nullptr) {} 77 ~TestCollectiveExecutorMgr()78 ~TestCollectiveExecutorMgr() override { 79 for (auto& iter : table_) { 80 iter.second->Unref(); 81 } 82 } 83 FindOrCreate(int64_t step_id)84 CollectiveExecutor* FindOrCreate(int64_t step_id) override { 85 mutex_lock l(mu_); 86 CollectiveExecutor* ce = nullptr; 87 auto iter = table_.find(step_id); 88 if (iter != table_.end()) { 89 ce = iter->second; 90 } else { 91 ce = new TestCollectiveExecutor(this, rma_); 92 table_[step_id] = ce; 93 } 94 ce->Ref(); 95 return ce; 96 } 97 Cleanup(int64_t step_id)98 void Cleanup(int64_t step_id) override { 99 mutex_lock l(mu_); 100 auto iter = table_.find(step_id); 101 if (iter != table_.end()) { 102 iter->second->Unref(); 103 table_.erase(iter); 104 } 105 } 106 GetParamResolver()107 ParamResolverInterface* GetParamResolver() const override { 108 return param_resolver_; 109 } 110 GetDeviceResolver()111 DeviceResolverInterface* GetDeviceResolver() const override { 112 LOG(FATAL); 113 return nullptr; 114 } 115 GetNcclCommunicator()116 NcclCommunicatorInterface* GetNcclCommunicator() const override { 117 return nullptr; 118 } 119 GetStepSequenceAsync(const GetStepSequenceRequest * request,GetStepSequenceResponse * response,const StatusCallback & done)120 void GetStepSequenceAsync(const GetStepSequenceRequest* request, 121 GetStepSequenceResponse* response, 122 const StatusCallback& done) override { 123 done(errors::Internal("unimplemented")); 124 } 125 RefreshStepIdSequenceAsync(int64_t graph_key,const StatusCallback & done)126 void RefreshStepIdSequenceAsync(int64_t graph_key, 127 const StatusCallback& done) override { 128 done(errors::Internal("unimplemented")); 129 } 130 NextStepId(int64_t graph_key)131 int64_t NextStepId(int64_t graph_key) override { 132 return CollectiveExecutor::kInvalidId; 133 } 134 RetireStepId(int64_t graph_key,int64_t step_id)135 void RetireStepId(int64_t graph_key, int64_t step_id) override {} 136 137 protected: 138 mutex mu_; 139 gtl::FlatMap<int64_t, CollectiveExecutor*> table_ TF_GUARDED_BY(mu_); 140 ParamResolverInterface* param_resolver_; 141 CollectiveRemoteAccess* rma_; 142 }; 143 144 } // namespace tensorflow 145 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_TEST_COLLECTIVE_EXECUTOR_MGR_H_ 146