xref: /aosp_15_r20/external/tensorflow/tensorflow/core/common_runtime/test_collective_executor_mgr.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_TEST_COLLECTIVE_EXECUTOR_MGR_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_TEST_COLLECTIVE_EXECUTOR_MGR_H_
17 
18 #include "tensorflow/core/framework/collective.h"
19 #include "tensorflow/core/framework/device_attributes.pb.h"
20 #include "tensorflow/core/lib/gtl/flatmap.h"
21 
22 namespace tensorflow {
23 
24 // Mock objects that can't actually execute a Collective, but satisfy
25 // general infrastructure expectations within tests that don't require
26 // full functionality.
27 
28 class TestCollectiveExecutor : public CollectiveExecutor {
29  public:
30   explicit TestCollectiveExecutor(CollectiveExecutorMgrInterface* cem,
31                                   CollectiveRemoteAccess* rma = nullptr)
CollectiveExecutor(cem)32       : CollectiveExecutor(cem), rma_(rma) {}
33 
RunClosure(std::function<void ()> fn)34   void RunClosure(std::function<void()> fn) override { fn(); }
35 
remote_access()36   CollectiveRemoteAccess* remote_access() override { return rma_; }
37 
38  private:
39   CollectiveRemoteAccess* rma_;
40 };
41 
42 class TestParamResolver : public ParamResolverInterface {
CompleteParamsAsync(const DeviceAttributes & device,CollectiveParams * cp,CancellationManager * cancel_mgr,const StatusCallback & done)43   void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp,
44                            CancellationManager* cancel_mgr,
45                            const StatusCallback& done) override {
46     done(errors::Internal("Unimplemented"));
47   }
48 
CompleteGroupAsync(const DeviceAttributes & device,CollGroupParams * group_params,CancellationManager * cancel_mgr,const StatusCallback & done)49   void CompleteGroupAsync(const DeviceAttributes& device,
50                           CollGroupParams* group_params,
51                           CancellationManager* cancel_mgr,
52                           const StatusCallback& done) override {
53     done(errors::Internal("Unimplemented"));
54   }
55 
CompleteInstanceAsync(const CompleteInstanceRequest * request,CompleteInstanceResponse * response,CancellationManager * cancel_mgr,const StatusCallback & done)56   void CompleteInstanceAsync(const CompleteInstanceRequest* request,
57                              CompleteInstanceResponse* response,
58                              CancellationManager* cancel_mgr,
59                              const StatusCallback& done) override {
60     done(errors::Internal("Unimplemented"));
61   }
62 
LookupGroup(int32_t group_key,CollGroupParams * group)63   Status LookupGroup(int32_t group_key, CollGroupParams* group) override {
64     return errors::Internal("Unimplemented");
65   }
66 
StartAbort(const Status & s)67   void StartAbort(const Status& s) override {}
68 };
69 
70 class TestCollectiveExecutorMgr : public CollectiveExecutorMgrInterface {
71  public:
TestCollectiveExecutorMgr(ParamResolverInterface * param_resolver,CollectiveRemoteAccess * rma)72   explicit TestCollectiveExecutorMgr(ParamResolverInterface* param_resolver,
73                                      CollectiveRemoteAccess* rma)
74       : param_resolver_(param_resolver), rma_(rma) {}
75 
TestCollectiveExecutorMgr()76   TestCollectiveExecutorMgr() : param_resolver_(nullptr), rma_(nullptr) {}
77 
~TestCollectiveExecutorMgr()78   ~TestCollectiveExecutorMgr() override {
79     for (auto& iter : table_) {
80       iter.second->Unref();
81     }
82   }
83 
FindOrCreate(int64_t step_id)84   CollectiveExecutor* FindOrCreate(int64_t step_id) override {
85     mutex_lock l(mu_);
86     CollectiveExecutor* ce = nullptr;
87     auto iter = table_.find(step_id);
88     if (iter != table_.end()) {
89       ce = iter->second;
90     } else {
91       ce = new TestCollectiveExecutor(this, rma_);
92       table_[step_id] = ce;
93     }
94     ce->Ref();
95     return ce;
96   }
97 
Cleanup(int64_t step_id)98   void Cleanup(int64_t step_id) override {
99     mutex_lock l(mu_);
100     auto iter = table_.find(step_id);
101     if (iter != table_.end()) {
102       iter->second->Unref();
103       table_.erase(iter);
104     }
105   }
106 
GetParamResolver()107   ParamResolverInterface* GetParamResolver() const override {
108     return param_resolver_;
109   }
110 
GetDeviceResolver()111   DeviceResolverInterface* GetDeviceResolver() const override {
112     LOG(FATAL);
113     return nullptr;
114   }
115 
GetNcclCommunicator()116   NcclCommunicatorInterface* GetNcclCommunicator() const override {
117     return nullptr;
118   }
119 
GetStepSequenceAsync(const GetStepSequenceRequest * request,GetStepSequenceResponse * response,const StatusCallback & done)120   void GetStepSequenceAsync(const GetStepSequenceRequest* request,
121                             GetStepSequenceResponse* response,
122                             const StatusCallback& done) override {
123     done(errors::Internal("unimplemented"));
124   }
125 
RefreshStepIdSequenceAsync(int64_t graph_key,const StatusCallback & done)126   void RefreshStepIdSequenceAsync(int64_t graph_key,
127                                   const StatusCallback& done) override {
128     done(errors::Internal("unimplemented"));
129   }
130 
NextStepId(int64_t graph_key)131   int64_t NextStepId(int64_t graph_key) override {
132     return CollectiveExecutor::kInvalidId;
133   }
134 
RetireStepId(int64_t graph_key,int64_t step_id)135   void RetireStepId(int64_t graph_key, int64_t step_id) override {}
136 
137  protected:
138   mutex mu_;
139   gtl::FlatMap<int64_t, CollectiveExecutor*> table_ TF_GUARDED_BY(mu_);
140   ParamResolverInterface* param_resolver_;
141   CollectiveRemoteAccess* rma_;
142 };
143 
144 }  // namespace tensorflow
145 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_TEST_COLLECTIVE_EXECUTOR_MGR_H_
146