1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TEST_UTILS_H_ 16 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TEST_UTILS_H_ 17 18 #include <unordered_map> 19 #include "tensorflow/core/distributed_runtime/worker_cache.h" 20 #include "tensorflow/core/distributed_runtime/worker_interface.h" 21 #include "tensorflow/core/util/device_name_utils.h" 22 23 namespace tensorflow { 24 25 // Some utilities for testing distributed-mode components in a single process 26 // without RPCs. 27 28 // Implements the worker interface with methods that just respond with 29 // "unimplemented" status. Override just the methods needed for 30 // testing. 31 class TestWorkerInterface : public WorkerInterface { 32 public: GetStatusAsync(CallOptions * opts,const GetStatusRequest * request,GetStatusResponse * response,bool fail_fast,StatusCallback done)33 void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request, 34 GetStatusResponse* response, bool fail_fast, 35 StatusCallback done) override { 36 done(errors::Unimplemented("GetStatusAsync")); 37 } 38 CreateWorkerSessionAsync(const CreateWorkerSessionRequest * request,CreateWorkerSessionResponse * response,StatusCallback done)39 void CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request, 40 CreateWorkerSessionResponse* response, 41 StatusCallback done) override { 42 done(errors::Unimplemented("CreateWorkerSessionAsync")); 43 } 44 DeleteWorkerSessionAsync(CallOptions * opts,const DeleteWorkerSessionRequest * request,DeleteWorkerSessionResponse * response,StatusCallback done)45 void DeleteWorkerSessionAsync(CallOptions* opts, 46 const DeleteWorkerSessionRequest* request, 47 DeleteWorkerSessionResponse* response, 48 StatusCallback done) override { 49 done(errors::Unimplemented("DeleteWorkerSessionAsync")); 50 } 51 RegisterGraphAsync(const RegisterGraphRequest * request,RegisterGraphResponse * response,StatusCallback done)52 void RegisterGraphAsync(const RegisterGraphRequest* request, 53 RegisterGraphResponse* response, 54 StatusCallback done) override { 55 done(errors::Unimplemented("RegisterGraphAsync")); 56 } 57 DeregisterGraphAsync(const DeregisterGraphRequest * request,DeregisterGraphResponse * response,StatusCallback done)58 void DeregisterGraphAsync(const DeregisterGraphRequest* request, 59 DeregisterGraphResponse* response, 60 StatusCallback done) override { 61 done(errors::Unimplemented("DeregisterGraphAsync")); 62 } 63 RunGraphAsync(CallOptions * opts,RunGraphRequestWrapper * request,MutableRunGraphResponseWrapper * response,StatusCallback done)64 void RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request, 65 MutableRunGraphResponseWrapper* response, 66 StatusCallback done) override { 67 done(errors::Unimplemented("RunGraphAsync")); 68 } 69 CleanupGraphAsync(const CleanupGraphRequest * request,CleanupGraphResponse * response,StatusCallback done)70 void CleanupGraphAsync(const CleanupGraphRequest* request, 71 CleanupGraphResponse* response, 72 StatusCallback done) override { 73 done(errors::Unimplemented("CleanupGraphAsync")); 74 } 75 CleanupAllAsync(const CleanupAllRequest * request,CleanupAllResponse * response,StatusCallback done)76 void CleanupAllAsync(const CleanupAllRequest* request, 77 CleanupAllResponse* response, 78 StatusCallback done) override { 79 done(errors::Unimplemented("CleanupAllAsync")); 80 } 81 RecvTensorAsync(CallOptions * opts,const RecvTensorRequest * request,TensorResponse * response,StatusCallback done)82 void RecvTensorAsync(CallOptions* opts, const RecvTensorRequest* request, 83 TensorResponse* response, StatusCallback done) override { 84 done(errors::Unimplemented("RecvTensorAsync")); 85 } 86 LoggingAsync(const LoggingRequest * request,LoggingResponse * response,StatusCallback done)87 void LoggingAsync(const LoggingRequest* request, LoggingResponse* response, 88 StatusCallback done) override { 89 done(errors::Unimplemented("LoggingAsync")); 90 } 91 TracingAsync(const TracingRequest * request,TracingResponse * response,StatusCallback done)92 void TracingAsync(const TracingRequest* request, TracingResponse* response, 93 StatusCallback done) override { 94 done(errors::Unimplemented("TracingAsync")); 95 } 96 RecvBufAsync(CallOptions * opts,const RecvBufRequest * request,RecvBufResponse * response,StatusCallback done)97 void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, 98 RecvBufResponse* response, StatusCallback done) override { 99 done(errors::Unimplemented("RecvBufAsync")); 100 } 101 CompleteGroupAsync(CallOptions * opts,const CompleteGroupRequest * request,CompleteGroupResponse * response,StatusCallback done)102 void CompleteGroupAsync(CallOptions* opts, 103 const CompleteGroupRequest* request, 104 CompleteGroupResponse* response, 105 StatusCallback done) override { 106 done(errors::Unimplemented("CompleteGroupAsync")); 107 } 108 CompleteInstanceAsync(CallOptions * ops,const CompleteInstanceRequest * request,CompleteInstanceResponse * response,StatusCallback done)109 void CompleteInstanceAsync(CallOptions* ops, 110 const CompleteInstanceRequest* request, 111 CompleteInstanceResponse* response, 112 StatusCallback done) override { 113 done(errors::Unimplemented("CompleteInstanceAsync")); 114 } 115 GetStepSequenceAsync(const GetStepSequenceRequest * request,GetStepSequenceResponse * response,StatusCallback done)116 void GetStepSequenceAsync(const GetStepSequenceRequest* request, 117 GetStepSequenceResponse* response, 118 StatusCallback done) override { 119 done(errors::Unimplemented("GetStepSequenceAsync")); 120 } 121 }; 122 123 class TestWorkerCache : public WorkerCacheInterface { 124 public: ~TestWorkerCache()125 virtual ~TestWorkerCache() {} 126 AddWorker(const string & target,WorkerInterface * wi)127 void AddWorker(const string& target, WorkerInterface* wi) { 128 workers_[target] = wi; 129 } 130 AddDevice(const string & device_name,const DeviceLocality & dev_loc)131 void AddDevice(const string& device_name, const DeviceLocality& dev_loc) { 132 localities_[device_name] = dev_loc; 133 } 134 ListWorkers(std::vector<string> * workers)135 void ListWorkers(std::vector<string>* workers) const override { 136 workers->clear(); 137 for (auto it : workers_) { 138 workers->push_back(it.first); 139 } 140 } 141 ListWorkersInJob(const string & job_name,std::vector<string> * workers)142 void ListWorkersInJob(const string& job_name, 143 std::vector<string>* workers) const override { 144 workers->clear(); 145 for (auto it : workers_) { 146 DeviceNameUtils::ParsedName device_name; 147 CHECK(DeviceNameUtils::ParseFullName(it.first, &device_name)); 148 CHECK(device_name.has_job); 149 if (job_name == device_name.job) { 150 workers->push_back(it.first); 151 } 152 } 153 } 154 GetOrCreateWorker(const string & target)155 WorkerInterface* GetOrCreateWorker(const string& target) override { 156 auto it = workers_.find(target); 157 if (it != workers_.end()) { 158 return it->second; 159 } 160 return nullptr; 161 } 162 ReleaseWorker(const string & target,WorkerInterface * worker)163 void ReleaseWorker(const string& target, WorkerInterface* worker) override {} 164 GetEagerClientCache(std::unique_ptr<eager::EagerClientCache> * eager_client_cache)165 Status GetEagerClientCache( 166 std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override { 167 return errors::Unimplemented("Unimplemented."); 168 } 169 GetCoordinationClientCache(std::unique_ptr<CoordinationClientCache> * coord_client_cache)170 Status GetCoordinationClientCache( 171 std::unique_ptr<CoordinationClientCache>* coord_client_cache) override { 172 return errors::Unimplemented("Unimplemented."); 173 } 174 GetDeviceLocalityNonBlocking(const string & device,DeviceLocality * locality)175 bool GetDeviceLocalityNonBlocking(const string& device, 176 DeviceLocality* locality) override { 177 auto it = localities_.find(device); 178 if (it != localities_.end()) { 179 *locality = it->second; 180 return true; 181 } 182 return false; 183 } 184 GetDeviceLocalityAsync(const string & device,DeviceLocality * locality,StatusCallback done)185 void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality, 186 StatusCallback done) override { 187 auto it = localities_.find(device); 188 if (it != localities_.end()) { 189 *locality = it->second; 190 done(OkStatus()); 191 return; 192 } 193 done(errors::Internal("Device not found: ", device)); 194 } 195 196 protected: 197 std::unordered_map<string, WorkerInterface*> workers_; 198 std::unordered_map<string, DeviceLocality> localities_; 199 }; 200 201 } // namespace tensorflow 202 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TEST_UTILS_H_ 203